Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-15 09:36:10 +00:00
fix: auth sql store: user is owner policy (#2674)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Installer CI / lint (push) Failing after 4s
Installer CI / smoke-test (push) Has been skipped
Integration Tests / discover-tests (push) Successful in 5s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 4s
Python Package Build Test / build (3.12) (push) Failing after 7s
Python Package Build Test / build (3.13) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 12s
Test Llama Stack Build / generate-matrix (push) Successful in 10s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 14s
Unit Tests / unit-tests (3.13) (push) Failing after 8s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 13s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 12s
Update ReadTheDocs / update-readthedocs (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 13s
Test Llama Stack Build / build-single-provider (push) Failing after 13s
Integration Tests / test-matrix (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 17s
Unit Tests / unit-tests (3.12) (push) Failing after 13s
Test Llama Stack Build / build (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 15s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 20s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 17s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m8s
# What does this PR do?

The current authorized sql store implementation does not respect user.principal (it only checks attributes). This PR addresses that.

## Test Plan

Added test cases to the integration tests.
This commit is contained in: parent 4cf1952c32, commit d880c2df0e
5 changed files with 247 additions and 175 deletions
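For context, the ownership check being fixed is exercised by the new integration test with a policy of the following shape (taken from the test code in this diff; a usage sketch, not additional API surface):

from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope

# Grant READ only when the requesting user's principal matches the record's owner.
owner_only_policy = [
    AccessRule(
        permit=Scope(actions=[Action.READ]),
        when=["user is owner"],
    ),
]

Before this change, SQL records were always given a synthetic "system" / "system_public" owner, so a rule like this could never match the actual caller. The diff below stores the caller's principal in a new owner_principal column and rebuilds the real owner when rows are read back.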
@@ -81,7 +81,7 @@ def is_action_allowed(
     if not len(policy):
         policy = default_policy()
 
-    qualified_resource_id = resource.type + "::" + resource.identifier
+    qualified_resource_id = f"{resource.type}::{resource.identifier}"
     for rule in policy:
         if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
             if rule.when:
@@ -39,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [
 
 
 class SqlRecord(ProtectedResource):
-    """Simple ProtectedResource implementation for SQL records."""
-
-    def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
+    def __init__(self, record_id: str, table_name: str, owner: User):
         self.type = f"sql_record::{table_name}"
         self.identifier = record_id
-
-        if access_attributes:
-            self.owner = User(
-                principal="system",
-                attributes=access_attributes,
-            )
-        else:
-            self.owner = User(
-                principal="system_public",
-                attributes=None,
-            )
+        self.owner = owner
 
 
 class AuthorizedSqlStore:
@@ -101,22 +89,27 @@ class AuthorizedSqlStore:
     async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
         """Create a table with built-in access control support."""
-        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
-
         enhanced_schema = dict(schema)
         if "access_attributes" not in enhanced_schema:
             enhanced_schema["access_attributes"] = ColumnType.JSON
+        if "owner_principal" not in enhanced_schema:
+            enhanced_schema["owner_principal"] = ColumnType.STRING
 
         await self.sql_store.create_table(table, enhanced_schema)
+        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
+        await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
 
     async def insert(self, table: str, data: Mapping[str, Any]) -> None:
         """Insert a row with automatic access control attribute capture."""
         enhanced_data = dict(data)
 
         current_user = get_authenticated_user()
-        if current_user and current_user.attributes:
+        if current_user:
+            enhanced_data["owner_principal"] = current_user.principal
             enhanced_data["access_attributes"] = current_user.attributes
         else:
+            enhanced_data["owner_principal"] = None
             enhanced_data["access_attributes"] = None
 
         await self.sql_store.insert(table, enhanced_data)
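Read literally, the new insert() path stores who wrote the row alongside their attributes. For an authenticated caller the enhanced payload would look roughly like this (illustrative values, not output captured from this PR):

# caller authenticated as User(principal="alice", attributes={"roles": ["admin"]})
enhanced_data = {
    "id": "1",
    "data": "hello",
    "owner_principal": "alice",
    "access_attributes": {"roles": ["admin"]},
}
# an anonymous insert stores None for both owner_principal and access_attributes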
@@ -146,9 +139,12 @@ class AuthorizedSqlStore:
         for row in rows.data:
             stored_access_attrs = row.get("access_attributes")
+            stored_owner_principal = row.get("owner_principal") or ""
 
             record_id = row.get("id", "unknown")
-            sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
+            sql_record = SqlRecord(
+                str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
+            )
 
             if is_action_allowed(policy, Action.READ, sql_record, current_user):
                 filtered_rows.append(row)
@@ -186,8 +182,10 @@ class AuthorizedSqlStore:
         Only applies SQL filtering for the default policy to ensure correctness.
         For custom policies, uses conservative filtering to avoid blocking legitimate access.
         """
+        current_user = get_authenticated_user()
+
         if not policy or policy == SQL_OPTIMIZED_POLICY:
-            return self._build_default_policy_where_clause()
+            return self._build_default_policy_where_clause(current_user)
         else:
             return self._build_conservative_where_clause()
@@ -227,29 +225,27 @@ class AuthorizedSqlStore:
 
     def _get_public_access_conditions(self) -> list[str]:
         """Get the SQL conditions for public access."""
+        # Public records are records that have no owner_principal or access_attributes
+        conditions = ["owner_principal = ''"]
         if self.database_type == SqlStoreType.postgres:
             # Postgres stores JSON null as 'null'
-            return ["access_attributes::text = 'null'"]
+            conditions.append("access_attributes::text = 'null'")
         elif self.database_type == SqlStoreType.sqlite:
-            return ["access_attributes = 'null'"]
+            conditions.append("access_attributes = 'null'")
         else:
             raise ValueError(f"Unsupported database type: {self.database_type}")
+        return conditions
 
-    def _build_default_policy_where_clause(self) -> str:
+    def _build_default_policy_where_clause(self, current_user: User | None) -> str:
         """Build SQL WHERE clause for the default policy.
 
         Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
         This means user must match ALL attribute categories that exist in the resource.
         """
-        current_user = get_authenticated_user()
-
         base_conditions = self._get_public_access_conditions()
-        if not current_user or not current_user.attributes:
-            # Only allow public records
-            return f"({' OR '.join(base_conditions)})"
-        else:
-            user_attr_conditions = []
-
+        user_attr_conditions = []
+
+        if current_user and current_user.attributes:
             for attr_key, user_values in current_user.attributes.items():
                 if user_values:
                     value_conditions = []
@@ -269,7 +265,7 @@ class AuthorizedSqlStore:
             all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
             base_conditions.append(all_requirements_met)
 
         return f"({' OR '.join(base_conditions)})"
 
     def _build_conservative_where_clause(self) -> str:
         """Conservative SQL filtering for custom policies.
@@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         engine = create_async_engine(self.config.engine_str)
 
         try:
-            inspector = inspect(engine)
-
-            table_names = inspector.get_table_names()
-            if table not in table_names:
-                return
-
-            existing_columns = inspector.get_columns(table)
-            column_names = [col["name"] for col in existing_columns]
-
-            if column_name in column_names:
-                return
-
-            sqlalchemy_type = TYPE_MAPPING.get(column_type)
-            if not sqlalchemy_type:
-                raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
-
-            # Create the ALTER TABLE statement
-            # Note: We need to get the dialect-specific type name
-            dialect = engine.dialect
-            type_impl = sqlalchemy_type()
-            compiled_type = type_impl.compile(dialect=dialect)
-
-            nullable_clause = "" if nullable else " NOT NULL"
-            add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
-
             async with engine.begin() as conn:
+
+                def check_column_exists(sync_conn):
+                    inspector = inspect(sync_conn)
+
+                    table_names = inspector.get_table_names()
+                    if table not in table_names:
+                        return False, False  # table doesn't exist, column doesn't exist
+
+                    existing_columns = inspector.get_columns(table)
+                    column_names = [col["name"] for col in existing_columns]
+
+                    return True, column_name in column_names  # table exists, column exists or not
+
+                table_exists, column_exists = await conn.run_sync(check_column_exists)
+                if not table_exists or column_exists:
+                    return
+
+                sqlalchemy_type = TYPE_MAPPING.get(column_type)
+                if not sqlalchemy_type:
+                    raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
+
+                # Create the ALTER TABLE statement
+                # Note: We need to get the dialect-specific type name
+                dialect = engine.dialect
+                type_impl = sqlalchemy_type()
+                compiled_type = type_impl.compile(dialect=dialect)
+
+                nullable_clause = "" if nullable else " NOT NULL"
+                add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
+
                 await conn.execute(add_column_sql)
 
-        except Exception:
+        except Exception as e:
             # If any error occurs during migration, log it but don't fail
             # The table creation will handle adding the column
+            logger.error(f"Error adding column {column_name} to table {table}: {e}")
             pass
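One structural point in the hunk above: SQLAlchemy's inspect() only accepts a synchronous connection, so the schema inspection now runs through AsyncConnection.run_sync inside the async transaction. A minimal standalone sketch of that pattern (not code from this PR; assumes the aiosqlite driver is installed):

import asyncio

from sqlalchemy import inspect, text
from sqlalchemy.ext.asyncio import create_async_engine


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.execute(text("CREATE TABLE demo (id TEXT)"))

        # inspect() is synchronous, so hop over to the sync side of the connection
        def list_columns(sync_conn):
            return [col["name"] for col in inspect(sync_conn).get_columns("demo")]

        print(await conn.run_sync(list_columns))  # ['id']


asyncio.run(main())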
@@ -14,8 +14,7 @@ from llama_stack.distribution.access_control.access_control import default_policy
 from llama_stack.distribution.datatypes import User
 from llama_stack.providers.utils.sqlstore.api import ColumnType
 from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
-from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
 
 
 def get_postgres_config():
@@ -30,144 +29,213 @@ def get_postgres_config():
 
 
 def get_sqlite_config():
-    """Get SQLite configuration with temporary database."""
-    tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
-    tmp_file.close()
-    return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name
+    """Get SQLite configuration with temporary file database."""
+    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    temp_file.close()
+    return SqliteSqlStoreConfig(db_path=temp_file.name)
+
+
+# Backend configurations for parametrized tests
+BACKEND_CONFIGS = [
+    pytest.param(
+        get_postgres_config,
+        marks=pytest.mark.skipif(
+            not os.environ.get("ENABLE_POSTGRES_TESTS"),
+            reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
+        ),
+        id="postgres",
+    ),
+    pytest.param(get_sqlite_config, id="sqlite"),
+]
+
+
+@pytest.fixture
+def authorized_store(backend_config):
+    """Set up authorized store with proper cleanup."""
+    config_func = backend_config
+
+    config = config_func()
+
+    base_sqlstore = sqlstore_impl(config)
+    authorized_store = AuthorizedSqlStore(base_sqlstore)
+
+    yield authorized_store
+
+    if hasattr(config, "db_path"):
+        try:
+            os.unlink(config.db_path)
+        except (OSError, FileNotFoundError):
+            pass
+
+
+async def create_test_table(authorized_store, table_name):
+    """Create a test table with standard schema."""
+    await authorized_store.create_table(
+        table=table_name,
+        schema={
+            "id": ColumnType.STRING,
+            "data": ColumnType.STRING,
+        },
+    )
+
+
+async def cleanup_records(sql_store, table_name, record_ids):
+    """Clean up test records."""
+    for record_id in record_ids:
+        try:
+            await sql_store.delete(table_name, {"id": record_id})
+        except Exception:
+            pass
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "backend_config",
-    [
-        pytest.param(
-            ("postgres", get_postgres_config),
-            marks=pytest.mark.skipif(
-                not os.environ.get("ENABLE_POSTGRES_TESTS"),
-                reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
-            ),
-            id="postgres",
-        ),
-        pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
-    ],
-)
+@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
 @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
-async def test_json_comparison(mock_get_authenticated_user, backend_config):
+async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
     """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
-    backend_name, config_func = backend_config
-
-    # Handle different config types
-    if backend_name == "postgres":
-        config = config_func()
-        cleanup_path = None
-    else:  # sqlite
-        config, cleanup_path = config_func()
+    backend_name = request.node.callspec.id
+
+    # Create test table
+    table_name = f"test_json_comparison_{backend_name}"
+    await create_test_table(authorized_store, table_name)
 
     try:
-        base_sqlstore = SqlAlchemySqlStoreImpl(config)
-        authorized_store = AuthorizedSqlStore(base_sqlstore)
-
-        # Create test table
-        table_name = f"test_json_comparison_{backend_name}"
-        await authorized_store.create_table(
-            table=table_name,
-            schema={
-                "id": ColumnType.STRING,
-                "data": ColumnType.STRING,
-            },
-        )
-
-        try:
         # Test with no authenticated user (should handle JSON null comparison)
         mock_get_authenticated_user.return_value = None
 
         # Insert some test data
         await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
 
         # Test fetching with no user - should not error on JSON comparison
         result = await authorized_store.fetch_all(table_name, policy=default_policy())
         assert len(result.data) == 1
         assert result.data[0]["id"] == "1"
         assert result.data[0]["access_attributes"] is None
 
         # Test with authenticated user
         test_user = User("test-user", {"roles": ["admin"]})
         mock_get_authenticated_user.return_value = test_user
 
         # Insert data with user attributes
         await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
 
         # Fetch all - admin should see both
         result = await authorized_store.fetch_all(table_name, policy=default_policy())
         assert len(result.data) == 2
 
         # Test with non-admin user
         regular_user = User("regular-user", {"roles": ["user"]})
         mock_get_authenticated_user.return_value = regular_user
 
         # Should only see public record
         result = await authorized_store.fetch_all(table_name, policy=default_policy())
         assert len(result.data) == 1
         assert result.data[0]["id"] == "1"
 
         # Test the category missing branch: user with multiple attributes
         multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
         mock_get_authenticated_user.return_value = multi_user
 
         # Insert record with multi-user (has both roles and teams)
         await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})
 
         # Test different user types to create records with different attribute patterns
         # Record with only roles (teams category will be missing)
         roles_only_user = User("roles-user", {"roles": ["admin"]})
         mock_get_authenticated_user.return_value = roles_only_user
         await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})
 
         # Record with only teams (roles category will be missing)
         teams_only_user = User("teams-user", {"teams": ["dev"]})
         mock_get_authenticated_user.return_value = teams_only_user
         await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})
 
         # Record with different roles/teams (shouldn't match our test user)
         different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
         mock_get_authenticated_user.return_value = different_user
         await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})
 
         # Now test with the multi-user who has both roles=admin and teams=dev
         mock_get_authenticated_user.return_value = multi_user
         result = await authorized_store.fetch_all(table_name, policy=default_policy())
 
         # Should see:
         # - public record (1) - no access_attributes
         # - admin record (2) - user matches roles=admin, teams missing (allowed)
         # - multi_user record (3) - user matches both roles=admin and teams=dev
         # - roles_only record (4) - user matches roles=admin, teams missing (allowed)
         # - teams_only record (5) - user matches teams=dev, roles missing (allowed)
         # Should NOT see:
         # - different_user record (6) - user doesn't match roles=user or teams=qa
         expected_ids = {"1", "2", "3", "4", "5"}
         actual_ids = {record["id"] for record in result.data}
         assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"
 
         # Verify the category missing logic specifically
         # Records 4 and 5 test the "category missing" branch where one attribute category is missing
         category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
         assert category_test_ids == {"4", "5"}, (
             f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
         )
 
-        finally:
-            # Clean up records
-            for record_id in ["1", "2", "3", "4", "5", "6"]:
-                try:
-                    await base_sqlstore.delete(table_name, {"id": record_id})
-                except Exception:
-                    pass
-
     finally:
-        # Clean up temporary SQLite database file if needed
-        if cleanup_path:
-            try:
-                os.unlink(cleanup_path)
-            except OSError:
-                pass
+        # Clean up records
+        await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
+@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):
+    """Test that 'user is owner' policies work correctly with record ownership"""
+    from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
+
+    backend_name = request.node.callspec.id
+
+    # Create test table
+    table_name = f"test_ownership_{backend_name}"
+    await create_test_table(authorized_store, table_name)
+
+    try:
+        # Test with first user who creates records
+        user1 = User("user1", {"roles": ["admin"]})
+        mock_get_authenticated_user.return_value = user1
+
+        # Insert a record owned by user1
+        await authorized_store.insert(table_name, {"id": "1", "data": "user1_data"})
+
+        # Test with second user
+        user2 = User("user2", {"roles": ["user"]})
+        mock_get_authenticated_user.return_value = user2
+
+        # Insert a record owned by user2
+        await authorized_store.insert(table_name, {"id": "2", "data": "user2_data"})
+
+        # Create a policy that only allows access when user is the owner
+        owner_only_policy = [
+            AccessRule(
+                permit=Scope(actions=[Action.READ]),
+                when=["user is owner"],
+            ),
+        ]
+
+        # Test user1 access - should only see their own record
+        mock_get_authenticated_user.return_value = user1
+        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
+        assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
+
+        # Test user2 access - should only see their own record
+        mock_get_authenticated_user.return_value = user2
+        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
+        assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
+
+        # Test with anonymous user - should see no records
+        mock_get_authenticated_user.return_value = None
+        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
+
+    finally:
+        # Clean up records
+        await cleanup_records(authorized_store.sql_store, table_name, ["1", "2"])
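To run just the new ownership test locally against the SQLite backend, a keyword-based selection along these lines should work (the test file path is not visible in this extract, and the Postgres parametrization stays skipped unless ENABLE_POSTGRES_TESTS is set):

pytest -v -k "user_ownership_policy and sqlite"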
@@ -153,7 +153,9 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
     policy_ids = set()
     for scenario in test_scenarios:
         sql_record = SqlRecord(
-            record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"]
+            record_id=scenario["id"],
+            table_name="resources",
+            owner=User(principal="test-user", attributes=scenario["access_attributes"]),
         )
 
         if is_action_allowed(policy, Action.READ, sql_record, user):