mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-15 09:36:10 +00:00
fix: auth sql store: user is owner policy (#2674)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Installer CI / lint (push) Failing after 4s
Installer CI / smoke-test (push) Has been skipped
Integration Tests / discover-tests (push) Successful in 5s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 4s
Python Package Build Test / build (3.12) (push) Failing after 7s
Python Package Build Test / build (3.13) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 12s
Test Llama Stack Build / generate-matrix (push) Successful in 10s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 14s
Unit Tests / unit-tests (3.13) (push) Failing after 8s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 13s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 12s
Update ReadTheDocs / update-readthedocs (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 13s
Test Llama Stack Build / build-single-provider (push) Failing after 13s
Integration Tests / test-matrix (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 17s
Unit Tests / unit-tests (3.12) (push) Failing after 13s
Test Llama Stack Build / build (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 15s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 20s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 17s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m8s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Installer CI / lint (push) Failing after 4s
Installer CI / smoke-test (push) Has been skipped
Integration Tests / discover-tests (push) Successful in 5s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 4s
Python Package Build Test / build (3.12) (push) Failing after 7s
Python Package Build Test / build (3.13) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 12s
Test Llama Stack Build / generate-matrix (push) Successful in 10s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 14s
Unit Tests / unit-tests (3.13) (push) Failing after 8s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 13s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 12s
Update ReadTheDocs / update-readthedocs (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 13s
Test Llama Stack Build / build-single-provider (push) Failing after 13s
Integration Tests / test-matrix (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 17s
Unit Tests / unit-tests (3.12) (push) Failing after 13s
Test Llama Stack Build / build (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 15s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 20s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 17s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m8s
# What does this PR do? The current authorized sql store implementation does not respect user.principal (only checks attributes). This PR addresses that. ## Test Plan Added test cases to integration tests.
This commit is contained in:
parent
4cf1952c32
commit
d880c2df0e
5 changed files with 247 additions and 175 deletions
|
@ -81,7 +81,7 @@ def is_action_allowed(
|
|||
if not len(policy):
|
||||
policy = default_policy()
|
||||
|
||||
qualified_resource_id = resource.type + "::" + resource.identifier
|
||||
qualified_resource_id = f"{resource.type}::{resource.identifier}"
|
||||
for rule in policy:
|
||||
if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
|
||||
if rule.when:
|
||||
|
|
|
@ -39,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [
|
|||
|
||||
|
||||
class SqlRecord(ProtectedResource):
|
||||
"""Simple ProtectedResource implementation for SQL records."""
|
||||
|
||||
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
|
||||
def __init__(self, record_id: str, table_name: str, owner: User):
|
||||
self.type = f"sql_record::{table_name}"
|
||||
self.identifier = record_id
|
||||
|
||||
if access_attributes:
|
||||
self.owner = User(
|
||||
principal="system",
|
||||
attributes=access_attributes,
|
||||
)
|
||||
else:
|
||||
self.owner = User(
|
||||
principal="system_public",
|
||||
attributes=None,
|
||||
)
|
||||
self.owner = owner
|
||||
|
||||
|
||||
class AuthorizedSqlStore:
|
||||
|
@ -101,22 +89,27 @@ class AuthorizedSqlStore:
|
|||
|
||||
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
|
||||
"""Create a table with built-in access control support."""
|
||||
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
||||
|
||||
enhanced_schema = dict(schema)
|
||||
if "access_attributes" not in enhanced_schema:
|
||||
enhanced_schema["access_attributes"] = ColumnType.JSON
|
||||
if "owner_principal" not in enhanced_schema:
|
||||
enhanced_schema["owner_principal"] = ColumnType.STRING
|
||||
|
||||
await self.sql_store.create_table(table, enhanced_schema)
|
||||
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
||||
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
"""Insert a row with automatic access control attribute capture."""
|
||||
enhanced_data = dict(data)
|
||||
|
||||
current_user = get_authenticated_user()
|
||||
if current_user and current_user.attributes:
|
||||
if current_user:
|
||||
enhanced_data["owner_principal"] = current_user.principal
|
||||
enhanced_data["access_attributes"] = current_user.attributes
|
||||
else:
|
||||
enhanced_data["owner_principal"] = None
|
||||
enhanced_data["access_attributes"] = None
|
||||
|
||||
await self.sql_store.insert(table, enhanced_data)
|
||||
|
@ -146,9 +139,12 @@ class AuthorizedSqlStore:
|
|||
|
||||
for row in rows.data:
|
||||
stored_access_attrs = row.get("access_attributes")
|
||||
stored_owner_principal = row.get("owner_principal") or ""
|
||||
|
||||
record_id = row.get("id", "unknown")
|
||||
sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
|
||||
sql_record = SqlRecord(
|
||||
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
||||
)
|
||||
|
||||
if is_action_allowed(policy, Action.READ, sql_record, current_user):
|
||||
filtered_rows.append(row)
|
||||
|
@ -186,8 +182,10 @@ class AuthorizedSqlStore:
|
|||
Only applies SQL filtering for the default policy to ensure correctness.
|
||||
For custom policies, uses conservative filtering to avoid blocking legitimate access.
|
||||
"""
|
||||
current_user = get_authenticated_user()
|
||||
|
||||
if not policy or policy == SQL_OPTIMIZED_POLICY:
|
||||
return self._build_default_policy_where_clause()
|
||||
return self._build_default_policy_where_clause(current_user)
|
||||
else:
|
||||
return self._build_conservative_where_clause()
|
||||
|
||||
|
@ -227,29 +225,27 @@ class AuthorizedSqlStore:
|
|||
|
||||
def _get_public_access_conditions(self) -> list[str]:
|
||||
"""Get the SQL conditions for public access."""
|
||||
# Public records are records that have no owner_principal or access_attributes
|
||||
conditions = ["owner_principal = ''"]
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
# Postgres stores JSON null as 'null'
|
||||
return ["access_attributes::text = 'null'"]
|
||||
conditions.append("access_attributes::text = 'null'")
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
return ["access_attributes = 'null'"]
|
||||
conditions.append("access_attributes = 'null'")
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
return conditions
|
||||
|
||||
def _build_default_policy_where_clause(self) -> str:
|
||||
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
|
||||
"""Build SQL WHERE clause for the default policy.
|
||||
|
||||
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
|
||||
This means user must match ALL attribute categories that exist in the resource.
|
||||
"""
|
||||
current_user = get_authenticated_user()
|
||||
|
||||
base_conditions = self._get_public_access_conditions()
|
||||
if not current_user or not current_user.attributes:
|
||||
# Only allow public records
|
||||
return f"({' OR '.join(base_conditions)})"
|
||||
else:
|
||||
user_attr_conditions = []
|
||||
|
||||
if current_user and current_user.attributes:
|
||||
for attr_key, user_values in current_user.attributes.items():
|
||||
if user_values:
|
||||
value_conditions = []
|
||||
|
|
|
@ -244,16 +244,22 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
engine = create_async_engine(self.config.engine_str)
|
||||
|
||||
try:
|
||||
inspector = inspect(engine)
|
||||
async with engine.begin() as conn:
|
||||
|
||||
def check_column_exists(sync_conn):
|
||||
inspector = inspect(sync_conn)
|
||||
|
||||
table_names = inspector.get_table_names()
|
||||
if table not in table_names:
|
||||
return
|
||||
return False, False # table doesn't exist, column doesn't exist
|
||||
|
||||
existing_columns = inspector.get_columns(table)
|
||||
column_names = [col["name"] for col in existing_columns]
|
||||
|
||||
if column_name in column_names:
|
||||
return True, column_name in column_names # table exists, column exists or not
|
||||
|
||||
table_exists, column_exists = await conn.run_sync(check_column_exists)
|
||||
if not table_exists or column_exists:
|
||||
return
|
||||
|
||||
sqlalchemy_type = TYPE_MAPPING.get(column_type)
|
||||
|
@ -269,10 +275,10 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
nullable_clause = "" if nullable else " NOT NULL"
|
||||
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(add_column_sql)
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# If any error occurs during migration, log it but don't fail
|
||||
# The table creation will handle adding the column
|
||||
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||
pass
|
||||
|
|
|
@ -14,8 +14,7 @@ from llama_stack.distribution.access_control.access_control import default_polic
|
|||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
|
||||
|
||||
|
||||
def get_postgres_config():
|
||||
|
@ -30,45 +29,47 @@ def get_postgres_config():
|
|||
|
||||
|
||||
def get_sqlite_config():
|
||||
"""Get SQLite configuration with temporary database."""
|
||||
tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
tmp_file.close()
|
||||
return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name
|
||||
"""Get SQLite configuration with temporary file database."""
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
temp_file.close()
|
||||
return SqliteSqlStoreConfig(db_path=temp_file.name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"backend_config",
|
||||
[
|
||||
# Backend configurations for parametrized tests
|
||||
BACKEND_CONFIGS = [
|
||||
pytest.param(
|
||||
("postgres", get_postgres_config),
|
||||
get_postgres_config,
|
||||
marks=pytest.mark.skipif(
|
||||
not os.environ.get("ENABLE_POSTGRES_TESTS"),
|
||||
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
|
||||
),
|
||||
id="postgres",
|
||||
),
|
||||
pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
|
||||
],
|
||||
)
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_json_comparison(mock_get_authenticated_user, backend_config):
|
||||
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
|
||||
backend_name, config_func = backend_config
|
||||
pytest.param(get_sqlite_config, id="sqlite"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def authorized_store(backend_config):
|
||||
"""Set up authorized store with proper cleanup."""
|
||||
config_func = backend_config
|
||||
|
||||
# Handle different config types
|
||||
if backend_name == "postgres":
|
||||
config = config_func()
|
||||
cleanup_path = None
|
||||
else: # sqlite
|
||||
config, cleanup_path = config_func()
|
||||
|
||||
try:
|
||||
base_sqlstore = SqlAlchemySqlStoreImpl(config)
|
||||
base_sqlstore = sqlstore_impl(config)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||
|
||||
# Create test table
|
||||
table_name = f"test_json_comparison_{backend_name}"
|
||||
yield authorized_store
|
||||
|
||||
if hasattr(config, "db_path"):
|
||||
try:
|
||||
os.unlink(config.db_path)
|
||||
except (OSError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
|
||||
async def create_test_table(authorized_store, table_name):
|
||||
"""Create a test table with standard schema."""
|
||||
await authorized_store.create_table(
|
||||
table=table_name,
|
||||
schema={
|
||||
|
@ -77,6 +78,27 @@ async def test_json_comparison(mock_get_authenticated_user, backend_config):
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_records(sql_store, table_name, record_ids):
|
||||
"""Clean up test records."""
|
||||
for record_id in record_ids:
|
||||
try:
|
||||
await sql_store.delete(table_name, {"id": record_id})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
|
||||
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
|
||||
backend_name = request.node.callspec.id
|
||||
|
||||
# Create test table
|
||||
table_name = f"test_json_comparison_{backend_name}"
|
||||
await create_test_table(authorized_store, table_name)
|
||||
|
||||
try:
|
||||
# Test with no authenticated user (should handle JSON null comparison)
|
||||
mock_get_authenticated_user.return_value = None
|
||||
|
@ -158,16 +180,62 @@ async def test_json_comparison(mock_get_authenticated_user, backend_config):
|
|||
|
||||
finally:
|
||||
# Clean up records
|
||||
for record_id in ["1", "2", "3", "4", "5", "6"]:
|
||||
await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):
|
||||
"""Test that 'user is owner' policies work correctly with record ownership"""
|
||||
from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
|
||||
|
||||
backend_name = request.node.callspec.id
|
||||
|
||||
# Create test table
|
||||
table_name = f"test_ownership_{backend_name}"
|
||||
await create_test_table(authorized_store, table_name)
|
||||
|
||||
try:
|
||||
await base_sqlstore.delete(table_name, {"id": record_id})
|
||||
except Exception:
|
||||
pass
|
||||
# Test with first user who creates records
|
||||
user1 = User("user1", {"roles": ["admin"]})
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
|
||||
# Insert a record owned by user1
|
||||
await authorized_store.insert(table_name, {"id": "1", "data": "user1_data"})
|
||||
|
||||
# Test with second user
|
||||
user2 = User("user2", {"roles": ["user"]})
|
||||
mock_get_authenticated_user.return_value = user2
|
||||
|
||||
# Insert a record owned by user2
|
||||
await authorized_store.insert(table_name, {"id": "2", "data": "user2_data"})
|
||||
|
||||
# Create a policy that only allows access when user is the owner
|
||||
owner_only_policy = [
|
||||
AccessRule(
|
||||
permit=Scope(actions=[Action.READ]),
|
||||
when=["user is owner"],
|
||||
),
|
||||
]
|
||||
|
||||
# Test user1 access - should only see their own record
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
||||
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
|
||||
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
|
||||
|
||||
# Test user2 access - should only see their own record
|
||||
mock_get_authenticated_user.return_value = user2
|
||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
||||
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
|
||||
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
|
||||
|
||||
# Test with anonymous user - should see no records
|
||||
mock_get_authenticated_user.return_value = None
|
||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
||||
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
|
||||
|
||||
finally:
|
||||
# Clean up temporary SQLite database file if needed
|
||||
if cleanup_path:
|
||||
try:
|
||||
os.unlink(cleanup_path)
|
||||
except OSError:
|
||||
pass
|
||||
# Clean up records
|
||||
await cleanup_records(authorized_store.sql_store, table_name, ["1", "2"])
|
||||
|
|
|
@ -153,7 +153,9 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
|||
policy_ids = set()
|
||||
for scenario in test_scenarios:
|
||||
sql_record = SqlRecord(
|
||||
record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"]
|
||||
record_id=scenario["id"],
|
||||
table_name="resources",
|
||||
owner=User(principal="test-user", attributes=scenario["access_attributes"]),
|
||||
)
|
||||
|
||||
if is_action_allowed(policy, Action.READ, sql_record, user):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue