From d3b60507d7fe93f0f24bc55e9f8fc9279bf7d11e Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 20 Jun 2025 10:24:45 -0700 Subject: [PATCH] feat: support auth attributes in inference/responses stores (#2389) # What does this PR do? Inference/Response stores now store user attributes when inserting, and respects them when fetching. ## Test Plan pytest tests/unit/utils/test_sqlstore.py --- llama_stack/distribution/resolver.py | 2 +- llama_stack/distribution/routers/__init__.py | 4 +- .../inline/agents/meta_reference/agents.py | 2 +- .../utils/inference/inference_store.py | 22 +- .../utils/responses/responses_store.py | 21 +- llama_stack/providers/utils/sqlstore/api.py | 26 +- .../utils/sqlstore/authorized_sqlstore.py | 222 ++++++++++++++++++ .../utils/sqlstore/sqlalchemy_sqlstore.py | 61 ++++- .../meta_reference/test_openai_responses.py | 3 +- .../utils/inference/test_inference_store.py | 10 +- .../utils/responses/test_responses_store.py | 16 +- tests/unit/utils/test_authorized_sqlstore.py | 218 +++++++++++++++++ 12 files changed, 575 insertions(+), 32 deletions(-) create mode 100644 llama_stack/providers/utils/sqlstore/authorized_sqlstore.py create mode 100644 tests/unit/utils/test_authorized_sqlstore.py diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index e71ff8092..3726bb3a5 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -335,7 +335,7 @@ async def instantiate_provider( method = "get_auto_router_impl" config = None - args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config] + args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config, policy] elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 0a0c13880..8671a62e1 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -47,7 +47,7 @@ async def get_routing_table_impl( async def get_auto_router_impl( - api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig + api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule] ) -> Any: from .datasets import DatasetIORouter from .eval_scoring import EvalRouter, ScoringRouter @@ -78,7 +78,7 @@ async def get_auto_router_impl( # TODO: move pass configs to routers instead if api == Api.inference and run_config.inference_store: - inference_store = InferenceStore(run_config.inference_store) + inference_store = InferenceStore(run_config.inference_store, policy) await inference_store.initialize() api_to_dep_impl["store"] = inference_store diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 6b2acd8f3..89fadafb4 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -78,7 +78,7 @@ class MetaReferenceAgentsImpl(Agents): async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) - self.responses_store = ResponsesStore(self.config.responses_store) + self.responses_store = ResponsesStore(self.config.responses_store, self.policy) await self.responses_store.initialize() self.openai_responses_impl = OpenAIResponsesImpl( inference_api=self.inference_api, diff --git 
a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index ab43e48b1..60a87494e 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -10,24 +10,27 @@ from llama_stack.apis.inference import ( OpenAIMessageParam, Order, ) +from llama_stack.distribution.datatypes import AccessRule from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR from ..sqlstore.api import ColumnDefinition, ColumnType +from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl class InferenceStore: - def __init__(self, sql_store_config: SqlStoreConfig): + def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): if not sql_store_config: sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) self.sql_store_config = sql_store_config self.sql_store = None + self.policy = policy async def initialize(self): """Create the necessary tables if they don't exist.""" - self.sql_store = sqlstore_impl(self.sql_store_config) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) await self.sql_store.create_table( "chat_completions", { @@ -48,8 +51,8 @@ class InferenceStore: data = chat_completion.model_dump() await self.sql_store.insert( - "chat_completions", - { + table="chat_completions", + data={ "id": data["id"], "created": data["created"], "model": data["model"], @@ -89,6 +92,7 @@ class InferenceStore: order_by=[("created", order.value)], cursor=("id", after) if after else None, limit=limit, + policy=self.policy, ) data = [ @@ -112,9 +116,17 @@ class InferenceStore: if not self.sql_store: raise ValueError("Inference store is not initialized") - row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id}) + row = await self.sql_store.fetch_one( + table="chat_completions", + where={"id": completion_id}, + policy=self.policy, + ) + if not row: + # SecureSqlStore will return None if record doesn't exist OR access is denied + # This provides security by not revealing whether the record exists raise ValueError(f"Chat completion with id {completion_id} not found") from None + return OpenAICompletionWithInputMessages( id=row["id"], created=row["created"], diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 151b020f7..36151d1c3 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -13,19 +13,22 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) +from llama_stack.distribution.datatypes import AccessRule from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR from ..sqlstore.api import ColumnDefinition, ColumnType +from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl class ResponsesStore: - def __init__(self, sql_store_config: SqlStoreConfig): + def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): if not sql_store_config: sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = sqlstore_impl(sql_store_config) + self.sql_store = 
AuthorizedSqlStore(sqlstore_impl(sql_store_config)) + self.policy = policy async def initialize(self): """Create the necessary tables if they don't exist.""" @@ -83,6 +86,7 @@ class ResponsesStore: order_by=[("created_at", order.value)], cursor=("id", after) if after else None, limit=limit, + policy=self.policy, ) data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] @@ -94,9 +98,20 @@ class ResponsesStore: ) async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput: - row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) + """ + Get a response object with automatic access control checking. + """ + row = await self.sql_store.fetch_one( + "openai_responses", + where={"id": response_id}, + policy=self.policy, + ) + if not row: + # SecureSqlStore will return None if record doesn't exist OR access is denied + # This provides security by not revealing whether the record exists raise ValueError(f"Response with id {response_id} not found") from None + return OpenAIResponseObjectWithInput(**row["response_object"]) async def list_response_input_items( diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py index 248dbb38e..6bb85ea0c 100644 --- a/llama_stack/providers/utils/sqlstore/api.py +++ b/llama_stack/providers/utils/sqlstore/api.py @@ -51,6 +51,7 @@ class SqlStore(Protocol): self, table: str, where: Mapping[str, Any] | None = None, + where_sql: str | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, cursor: tuple[str, str] | None = None, @@ -59,7 +60,8 @@ class SqlStore(Protocol): Fetch all rows from a table with optional cursor-based pagination. :param table: The table name - :param where: WHERE conditions + :param where: Simple key-value WHERE conditions + :param where_sql: Raw SQL WHERE clause for complex queries :param limit: Maximum number of records to return :param order_by: List of (column, order) tuples for sorting :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page) @@ -75,6 +77,7 @@ class SqlStore(Protocol): self, table: str, where: Mapping[str, Any] | None = None, + where_sql: str | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: """ @@ -102,3 +105,24 @@ class SqlStore(Protocol): Delete a row from a table. """ pass + + async def add_column_if_not_exists( + self, + table: str, + column_name: str, + column_type: ColumnType, + nullable: bool = True, + ) -> None: + """ + Add a column to an existing table if the column doesn't already exist. + + This is useful for table migrations when adding new functionality. + If the table doesn't exist, this method should do nothing. + If the column already exists, this method should do nothing. + + :param table: Table name + :param column_name: Name of the column to add + :param column_type: Type of the column to add + :param nullable: Whether the column should be nullable (default: True) + """ + pass diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py new file mode 100644 index 000000000..520c8203b --- /dev/null +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from collections.abc import Mapping +from typing import Any, Literal + +from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed +from llama_stack.distribution.access_control.conditions import ProtectedResource +from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope +from llama_stack.distribution.datatypes import User +from llama_stack.distribution.request_headers import get_authenticated_user +from llama_stack.log import get_logger + +from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore + +logger = get_logger(name=__name__, category="authorized_sqlstore") + +# Hardcoded copy of the default policy that our SQL filtering implements +# WARNING: If default_policy() changes, this constant must be updated accordingly +# or SQL filtering will fall back to conservative mode (safe but less performant) +# +# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories" +# The corresponding SQL logic is implemented in _build_default_policy_where_clause(): +# - Public records (no access_attributes) are always accessible +# - Records with access_attributes require user to match ALL categories that exist in the resource +# - Missing categories in the resource are treated as "no restriction" (allow) +# - Within each category, user needs ANY matching value (OR logic) +# - Between categories, user needs ALL categories to match (AND logic) +SQL_OPTIMIZED_POLICY = [ + AccessRule( + permit=Scope(actions=list(Action)), + when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"], + ), +] + + +class SqlRecord(ProtectedResource): + """Simple ProtectedResource implementation for SQL records.""" + + def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None): + self.type = f"sql_record::{table_name}" + self.identifier = record_id + + if access_attributes: + self.owner = User( + principal="system", + attributes=access_attributes, + ) + else: + self.owner = User( + principal="system_public", + attributes=None, + ) + + +class AuthorizedSqlStore: + """ + Authorization layer for SqlStore that provides access control functionality. + + This class composes a base SqlStore and adds authorization methods that handle + access control policies, user attribute capture, and SQL filtering optimization. + """ + + def __init__(self, sql_store: SqlStore): + """ + Initialize the authorization layer. + + :param sql_store: Base SqlStore implementation to wrap + """ + self.sql_store = sql_store + + self._validate_sql_optimized_policy() + + def _validate_sql_optimized_policy(self) -> None: + """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy(). + + This ensures that if default_policy() changes, we detect the mismatch and + can update our SQL filtering logic accordingly. + """ + actual_default = default_policy() + + if SQL_OPTIMIZED_POLICY != actual_default: + logger.warning( + f"SQL_OPTIMIZED_POLICY does not match default_policy(). " + f"SQL filtering will use conservative mode. 
" + f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}", + ) + + async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: + """Create a table with built-in access control support.""" + await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) + + enhanced_schema = dict(schema) + if "access_attributes" not in enhanced_schema: + enhanced_schema["access_attributes"] = ColumnType.JSON + + await self.sql_store.create_table(table, enhanced_schema) + + async def insert(self, table: str, data: Mapping[str, Any]) -> None: + """Insert a row with automatic access control attribute capture.""" + enhanced_data = dict(data) + + current_user = get_authenticated_user() + if current_user and current_user.attributes: + enhanced_data["access_attributes"] = current_user.attributes + else: + enhanced_data["access_attributes"] = None + + await self.sql_store.insert(table, enhanced_data) + + async def fetch_all( + self, + table: str, + policy: list[AccessRule], + where: Mapping[str, Any] | None = None, + limit: int | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + cursor: tuple[str, str] | None = None, + ) -> PaginatedResponse: + """Fetch all rows with automatic access control filtering.""" + access_where = self._build_access_control_where_clause(policy) + rows = await self.sql_store.fetch_all( + table=table, + where=where, + where_sql=access_where, + limit=limit, + order_by=order_by, + cursor=cursor, + ) + + current_user = get_authenticated_user() + filtered_rows = [] + + for row in rows.data: + stored_access_attrs = row.get("access_attributes") + + record_id = row.get("id", "unknown") + sql_record = SqlRecord(str(record_id), table, stored_access_attrs) + + if is_action_allowed(policy, Action.READ, sql_record, current_user): + filtered_rows.append(row) + + return PaginatedResponse( + data=filtered_rows, + has_more=rows.has_more, + ) + + async def fetch_one( + self, + table: str, + policy: list[AccessRule], + where: Mapping[str, Any] | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> dict[str, Any] | None: + """Fetch one row with automatic access control checking.""" + results = await self.fetch_all( + table=table, + policy=policy, + where=where, + limit=1, + order_by=order_by, + ) + + return results.data[0] if results.data else None + + def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str: + """Build SQL WHERE clause for access control filtering. + + Only applies SQL filtering for the default policy to ensure correctness. + For custom policies, uses conservative filtering to avoid blocking legitimate access. + """ + if not policy or policy == SQL_OPTIMIZED_POLICY: + return self._build_default_policy_where_clause() + else: + return self._build_conservative_where_clause() + + def _build_default_policy_where_clause(self) -> str: + """Build SQL WHERE clause for the default policy. + + Default policy: permit all actions when user in owners [roles, teams, projects, namespaces] + This means user must match ALL attribute categories that exist in the resource. 
+ """ + current_user = get_authenticated_user() + + if not current_user or not current_user.attributes: + return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')" + else: + base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"] + + user_attr_conditions = [] + + for attr_key, user_values in current_user.attributes.items(): + if user_values: + value_conditions = [] + for value in user_values: + value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'") + + if value_conditions: + category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL" + user_matches_category = f"({' OR '.join(value_conditions)})" + user_attr_conditions.append(f"({category_missing} OR {user_matches_category})") + + if user_attr_conditions: + all_requirements_met = f"({' AND '.join(user_attr_conditions)})" + base_conditions.append(all_requirements_met) + return f"({' OR '.join(base_conditions)})" + else: + return f"({' OR '.join(base_conditions)})" + + def _build_conservative_where_clause(self) -> str: + """Conservative SQL filtering for custom policies. + + Only filters records we're 100% certain would be denied by any reasonable policy. + """ + current_user = get_authenticated_user() + + if not current_user: + return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')" + return "1=1" diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index db8180e74..3aecb0d59 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -17,15 +17,20 @@ from sqlalchemy import ( String, Table, Text, + inspect, select, + text, ) from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, SqlStore from .sqlstore import SqlAlchemySqlStoreConfig +logger = get_logger(name=__name__, category="sqlstore") + TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, ColumnType.STRING: String, @@ -56,7 +61,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): for col_name, col_props in schema.items(): col_type = None is_primary_key = False - is_nullable = True # Default to nullable + is_nullable = True if isinstance(col_props, ColumnType): col_type = col_props @@ -73,14 +78,11 @@ class SqlAlchemySqlStoreImpl(SqlStore): Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable) ) - # Check if table already exists in metadata, otherwise define it if table not in self.metadata.tables: sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns) else: sqlalchemy_table = self.metadata.tables[table] - # Create the table in the database if it doesn't exist - # checkfirst=True ensures it doesn't try to recreate if it's already there engine = create_async_engine(self.config.engine_str) async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) @@ -94,6 +96,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): self, table: str, where: Mapping[str, Any] | None = None, + where_sql: str | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, cursor: tuple[str, str] | None = None, @@ -106,6 +109,9 @@ class 
SqlAlchemySqlStoreImpl(SqlStore): for key, value in where.items(): query = query.where(table_obj.c[key] == value) + if where_sql: + query = query.where(text(where_sql)) + # Handle cursor-based pagination if cursor: # Validate cursor tuple format @@ -192,9 +198,10 @@ class SqlAlchemySqlStoreImpl(SqlStore): self, table: str, where: Mapping[str, Any] | None = None, + where_sql: str | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: - result = await self.fetch_all(table, where, limit=1, order_by=order_by) + result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by) if not result.data: return None return result.data[0] @@ -225,3 +232,47 @@ class SqlAlchemySqlStoreImpl(SqlStore): stmt = stmt.where(self.metadata.tables[table].c[key] == value) await session.execute(stmt) await session.commit() + + async def add_column_if_not_exists( + self, + table: str, + column_name: str, + column_type: ColumnType, + nullable: bool = True, + ) -> None: + """Add a column to an existing table if the column doesn't already exist.""" + engine = create_async_engine(self.config.engine_str) + + try: + inspector = inspect(engine) + + table_names = inspector.get_table_names() + if table not in table_names: + return + + existing_columns = inspector.get_columns(table) + column_names = [col["name"] for col in existing_columns] + + if column_name in column_names: + return + + sqlalchemy_type = TYPE_MAPPING.get(column_type) + if not sqlalchemy_type: + raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.") + + # Create the ALTER TABLE statement + # Note: We need to get the dialect-specific type name + dialect = engine.dialect + type_impl = sqlalchemy_type() + compiled_type = type_impl.compile(dialect=dialect) + + nullable_clause = "" if nullable else " NOT NULL" + add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") + + async with engine.begin() as conn: + await conn.execute(add_column_sql) + + except Exception: + # If any error occurs during migration, log it but don't fail + # The table creation will handle adding the column + pass diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 6bf1b7e0c..a3d798083 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -39,6 +39,7 @@ from llama_stack.apis.inference.inference import ( OpenAIUserMessageParam, ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime +from llama_stack.distribution.access_control.access_control import default_policy from llama_stack.providers.inline.agents.meta_reference.openai_responses import ( OpenAIResponsesImpl, ) @@ -599,7 +600,7 @@ async def test_responses_store_list_input_items_logic(): # Create mock store and response store mock_sql_store = AsyncMock() - responses_store = ResponsesStore(sql_store_config=None) + responses_store = ResponsesStore(sql_store_config=None, policy=default_policy()) responses_store.sql_store = mock_sql_store # Setup test data - multiple input items diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py index b30748617..de619c760 100644 --- a/tests/unit/utils/inference/test_inference_store.py +++ 
b/tests/unit/utils/inference/test_inference_store.py @@ -47,7 +47,7 @@ async def test_inference_store_pagination_basic(): """Test basic pagination functionality.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Create test data with different timestamps @@ -93,7 +93,7 @@ async def test_inference_store_pagination_ascending(): """Test pagination with ascending order.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Create test data @@ -128,7 +128,7 @@ async def test_inference_store_pagination_with_model_filter(): """Test pagination combined with model filtering.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Create test data with different models @@ -166,7 +166,7 @@ async def test_inference_store_pagination_invalid_after(): """Test error handling for invalid 'after' parameter.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Try to paginate with non-existent ID @@ -179,7 +179,7 @@ async def test_inference_store_pagination_no_limit(): """Test pagination behavior when no limit is specified.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Create test data diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 51fcb1ec2..3f25e2524 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -49,7 +49,7 @@ async def test_responses_store_pagination_basic(): """Test basic pagination functionality for responses store.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Create test data with different timestamps @@ -95,7 +95,7 @@ async def test_responses_store_pagination_ascending(): """Test pagination with ascending order.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Create test data @@ -130,7 +130,7 @@ async def test_responses_store_pagination_with_model_filter(): """Test pagination combined with model filtering.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Create test data with different models @@ -168,7 +168,7 @@ async def test_responses_store_pagination_invalid_after(): 
"""Test error handling for invalid 'after' parameter.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Try to paginate with non-existent ID @@ -181,7 +181,7 @@ async def test_responses_store_pagination_no_limit(): """Test pagination behavior when no limit is specified.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Create test data @@ -210,7 +210,7 @@ async def test_responses_store_get_response_object(): """Test retrieving a single response object.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Store a test response @@ -235,7 +235,7 @@ async def test_responses_store_input_items_pagination(): """Test pagination functionality for input items.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Store a test response with many inputs with explicit IDs @@ -313,7 +313,7 @@ async def test_responses_store_input_items_before_pagination(): """Test before pagination functionality for input items.""" with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db" - store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[]) await store.initialize() # Store a test response with many inputs with explicit IDs diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py new file mode 100644 index 000000000..b457176a7 --- /dev/null +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from tempfile import TemporaryDirectory +from unittest.mock import patch + +import pytest + +from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed +from llama_stack.distribution.access_control.datatypes import Action +from llama_stack.distribution.datatypes import User +from llama_stack.providers.utils.sqlstore.api import ColumnType +from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord +from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig + + +@pytest.mark.asyncio +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user): + """Test that fetch_all works correctly with where_sql for access control""" + with TemporaryDirectory() as tmp_dir: + db_name = "test_access_control.db" + base_sqlstore = SqlAlchemySqlStoreImpl( + SqliteSqlStoreConfig( + db_path=tmp_dir + "/" + db_name, + ) + ) + sqlstore = AuthorizedSqlStore(base_sqlstore) + + # Create table with access control + await sqlstore.create_table( + table="documents", + schema={ + "id": ColumnType.INTEGER, + "title": ColumnType.STRING, + "content": ColumnType.TEXT, + }, + ) + + admin_user = User("admin-user", {"roles": ["admin"], "teams": ["engineering"]}) + regular_user = User("regular-user", {"roles": ["user"], "teams": ["marketing"]}) + + # Set user attributes for creating documents + mock_get_authenticated_user.return_value = admin_user + + # Insert documents with access attributes + await sqlstore.insert("documents", {"id": 1, "title": "Admin Document", "content": "This is admin content"}) + + # Change user attributes + mock_get_authenticated_user.return_value = regular_user + + await sqlstore.insert("documents", {"id": 2, "title": "User Document", "content": "Public user content"}) + + # Test that access control works with where parameter + mock_get_authenticated_user.return_value = admin_user + + # Admin should see both documents + result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + assert len(result.data) == 1 + assert result.data[0]["title"] == "Admin Document" + + # User should only see their document + mock_get_authenticated_user.return_value = regular_user + + result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + assert len(result.data) == 0 + + result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) + assert len(result.data) == 1 + assert result.data[0]["title"] == "User Document" + + row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) + assert row is None + + row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) + assert row is not None + assert row["title"] == "User Document" + + +@pytest.mark.asyncio +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +async def test_sql_policy_consistency(mock_get_authenticated_user): + """Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic""" + with TemporaryDirectory() as tmp_dir: + db_name = "test_consistency.db" + base_sqlstore = SqlAlchemySqlStoreImpl( + SqliteSqlStoreConfig( + db_path=tmp_dir + "/" + db_name, + ) + ) + sqlstore = AuthorizedSqlStore(base_sqlstore) + + await sqlstore.create_table( + 
table="resources", + schema={ + "id": ColumnType.STRING, + "name": ColumnType.STRING, + }, + ) + + # Test scenarios with different access control patterns + test_scenarios = [ + # Scenario 1: Public record (no access control) + {"id": "1", "name": "public", "access_attributes": None}, + # Scenario 2: Empty access control (should be treated as public) + {"id": "2", "name": "empty", "access_attributes": {}}, + # Scenario 3: Record with roles requirement + {"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}}, + # Scenario 4: Record with multiple attribute categories + {"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}}, + # Scenario 5: Record with teams only (missing roles category) + {"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}}, + # Scenario 6: Record with roles and projects + { + "id": "6", + "name": "admin-project-x", + "access_attributes": {"roles": ["admin"], "projects": ["project-x"]}, + }, + ] + + mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"]}) + for scenario in test_scenarios: + await base_sqlstore.insert("resources", scenario) + + # Test with different user configurations + user_scenarios = [ + # User 1: No attributes (should only see public records) + {"principal": "user1", "attributes": None}, + # User 2: Empty attributes (should only see public records) + {"principal": "user2", "attributes": {}}, + # User 3: Admin role only + {"principal": "user3", "attributes": {"roles": ["admin"]}}, + # User 4: ML team only + {"principal": "user4", "attributes": {"teams": ["ml-team"]}}, + # User 5: Admin + ML team + {"principal": "user5", "attributes": {"roles": ["admin"], "teams": ["ml-team"]}}, + # User 6: Admin + Project X + {"principal": "user6", "attributes": {"roles": ["admin"], "projects": ["project-x"]}}, + # User 7: Different role (should only see public) + {"principal": "user7", "attributes": {"roles": ["viewer"]}}, + ] + + policy = default_policy() + + for user_data in user_scenarios: + user = User(principal=user_data["principal"], attributes=user_data["attributes"]) + mock_get_authenticated_user.return_value = user + + sql_results = await sqlstore.fetch_all("resources", policy=policy) + sql_ids = {row["id"] for row in sql_results.data} + policy_ids = set() + for scenario in test_scenarios: + sql_record = SqlRecord( + record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"] + ) + + if is_action_allowed(policy, Action.READ, sql_record, user): + policy_ids.add(scenario["id"]) + assert sql_ids == policy_ids, ( + f"Consistency failure for user {user.principal} with attributes {user.attributes}:\n" + f"SQL returned: {sorted(sql_ids)}\n" + f"Policy allows: {sorted(policy_ids)}\n" + f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}" + ) + + +@pytest.mark.asyncio +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user): + """Test that user attributes are properly captured during insert""" + with TemporaryDirectory() as tmp_dir: + db_name = "test_attributes.db" + base_sqlstore = SqlAlchemySqlStoreImpl( + SqliteSqlStoreConfig( + db_path=tmp_dir + "/" + db_name, + ) + ) + authorized_store = AuthorizedSqlStore(base_sqlstore) + + await authorized_store.create_table( + table="user_data", + schema={ + "id": ColumnType.STRING, + "content": ColumnType.STRING, + }, + ) 
+ + mock_get_authenticated_user.return_value = User( + "user-with-attrs", {"roles": ["editor"], "teams": ["content"], "projects": ["blog"]} + ) + + await authorized_store.insert("user_data", {"id": "item1", "content": "User content"}) + + mock_get_authenticated_user.return_value = User("user-no-attrs", None) + + await authorized_store.insert("user_data", {"id": "item2", "content": "Public content"}) + + mock_get_authenticated_user.return_value = None + + await authorized_store.insert("user_data", {"id": "item3", "content": "Anonymous content"}) + result = await base_sqlstore.fetch_all("user_data", order_by=[("id", "asc")]) + assert len(result.data) == 3 + + # First item should have full attributes + assert result.data[0]["id"] == "item1" + assert result.data[0]["access_attributes"] == {"roles": ["editor"], "teams": ["content"], "projects": ["blog"]} + + # Second item should have null attributes (user with no attributes) + assert result.data[1]["id"] == "item2" + assert result.data[1]["access_attributes"] is None + + # Third item should have null attributes (no authenticated user) + assert result.data[2]["id"] == "item3" + assert result.data[2]["access_attributes"] is None
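
Usage sketch (not part of the diff): a minimal, self-contained example of the `AuthorizedSqlStore` wrapper this patch introduces, modeled on `tests/unit/utils/test_authorized_sqlstore.py`. The SQLite path and the `documents` table below are illustrative only; the calls shown are the wrapper's public API as added here.

```python
import asyncio

from llama_stack.distribution.access_control.access_control import default_policy
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def main() -> None:
    # Wrap a plain SQL store; the wrapper adds an access_attributes JSON column.
    base_store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path="/tmp/authorized_demo.db"))
    store = AuthorizedSqlStore(base_store)

    await store.create_table(
        table="documents",
        schema={"id": ColumnType.INTEGER, "title": ColumnType.STRING},
    )

    # insert() stamps the row with the current user's attributes (None if unauthenticated).
    await store.insert("documents", {"id": 1, "title": "Team doc"})

    # Reads are filtered by the supplied policy; rows the caller may not read are omitted.
    rows = await store.fetch_all("documents", policy=default_policy())
    print([row["title"] for row in rows.data])


if __name__ == "__main__":
    asyncio.run(main())
```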
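The comment block above SQL_OPTIMIZED_POLICY describes how the default policy is pushed down into a SQL filter. One way to inspect that filter is to stub get_authenticated_user the same way the new unit tests do and call the internal helper directly; this is a sketch for illustration only, the user attributes and db path are invented, and the helper is private to the wrapper.

```python
from unittest.mock import patch

from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

store = AuthorizedSqlStore(SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path="/tmp/filter_demo.db")))

with patch(
    "llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user",
    return_value=User("demo-user", {"roles": ["admin"], "teams": ["ml-team"]}),
):
    # Prints roughly:
    # (access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}'
    #  OR ((JSON_EXTRACT(access_attributes, '$.roles') IS NULL
    #       OR (JSON_EXTRACT(access_attributes, '$.roles') LIKE '%"admin"%'))
    #      AND (JSON_EXTRACT(access_attributes, '$.teams') IS NULL
    #           OR (JSON_EXTRACT(access_attributes, '$.teams') LIKE '%"ml-team"%'))))
    print(store._build_default_policy_where_clause())
```

As the docstrings in the patch state, public records (no access_attributes) always pass; otherwise the user must match every attribute category present on the record, with OR within a category and AND across categories, mirroring is_action_allowed for the default policy.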