feat: support auth attributes in inference/responses stores (#2389)

# What does this PR do?
Inference/Response stores now store user attributes when inserting, and
respects them when fetching.

## Test Plan
pytest tests/unit/utils/test_sqlstore.py
This commit is contained in:
ehhuang 2025-06-20 10:24:45 -07:00 committed by GitHub
parent 7930c524f9
commit d3b60507d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 575 additions and 32 deletions

View file

@ -335,7 +335,7 @@ async def instantiate_provider(
method = "get_auto_router_impl" method = "get_auto_router_impl"
config = None config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config] args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config, policy]
elif isinstance(provider_spec, RoutingTableProviderSpec): elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl" method = "get_routing_table_impl"

View file

@ -47,7 +47,7 @@ async def get_routing_table_impl(
async def get_auto_router_impl( async def get_auto_router_impl(
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule]
) -> Any: ) -> Any:
from .datasets import DatasetIORouter from .datasets import DatasetIORouter
from .eval_scoring import EvalRouter, ScoringRouter from .eval_scoring import EvalRouter, ScoringRouter
@ -78,7 +78,7 @@ async def get_auto_router_impl(
# TODO: move pass configs to routers instead # TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store: if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store) inference_store = InferenceStore(run_config.inference_store, policy)
await inference_store.initialize() await inference_store.initialize()
api_to_dep_impl["store"] = inference_store api_to_dep_impl["store"] = inference_store

View file

@ -78,7 +78,7 @@ class MetaReferenceAgentsImpl(Agents):
async def initialize(self) -> None: async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store) self.persistence_store = await kvstore_impl(self.config.persistence_store)
self.responses_store = ResponsesStore(self.config.responses_store) self.responses_store = ResponsesStore(self.config.responses_store, self.policy)
await self.responses_store.initialize() await self.responses_store.initialize()
self.openai_responses_impl = OpenAIResponsesImpl( self.openai_responses_impl = OpenAIResponsesImpl(
inference_api=self.inference_api, inference_api=self.inference_api,

View file

@ -10,24 +10,27 @@ from llama_stack.apis.inference import (
OpenAIMessageParam, OpenAIMessageParam,
Order, Order,
) )
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
class InferenceStore: class InferenceStore:
def __init__(self, sql_store_config: SqlStoreConfig): def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
if not sql_store_config: if not sql_store_config:
sql_store_config = SqliteSqlStoreConfig( sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
) )
self.sql_store_config = sql_store_config self.sql_store_config = sql_store_config
self.sql_store = None self.sql_store = None
self.policy = policy
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
self.sql_store = sqlstore_impl(self.sql_store_config) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
await self.sql_store.create_table( await self.sql_store.create_table(
"chat_completions", "chat_completions",
{ {
@ -48,8 +51,8 @@ class InferenceStore:
data = chat_completion.model_dump() data = chat_completion.model_dump()
await self.sql_store.insert( await self.sql_store.insert(
"chat_completions", table="chat_completions",
{ data={
"id": data["id"], "id": data["id"],
"created": data["created"], "created": data["created"],
"model": data["model"], "model": data["model"],
@ -89,6 +92,7 @@ class InferenceStore:
order_by=[("created", order.value)], order_by=[("created", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [ data = [
@ -112,9 +116,17 @@ class InferenceStore:
if not self.sql_store: if not self.sql_store:
raise ValueError("Inference store is not initialized") raise ValueError("Inference store is not initialized")
row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id}) row = await self.sql_store.fetch_one(
table="chat_completions",
where={"id": completion_id},
policy=self.policy,
)
if not row: if not row:
# SecureSqlStore will return None if record doesn't exist OR access is denied
# This provides security by not revealing whether the record exists
raise ValueError(f"Chat completion with id {completion_id} not found") from None raise ValueError(f"Chat completion with id {completion_id} not found") from None
return OpenAICompletionWithInputMessages( return OpenAICompletionWithInputMessages(
id=row["id"], id=row["id"],
created=row["created"], created=row["created"],

View file

@ -13,19 +13,22 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObject, OpenAIResponseObject,
OpenAIResponseObjectWithInput, OpenAIResponseObjectWithInput,
) )
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
class ResponsesStore: class ResponsesStore:
def __init__(self, sql_store_config: SqlStoreConfig): def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
if not sql_store_config: if not sql_store_config:
sql_store_config = SqliteSqlStoreConfig( sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
) )
self.sql_store = sqlstore_impl(sql_store_config) self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
self.policy = policy
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
@ -83,6 +86,7 @@ class ResponsesStore:
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@ -94,9 +98,20 @@ class ResponsesStore:
) )
async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput: async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) """
Get a response object with automatic access control checking.
"""
row = await self.sql_store.fetch_one(
"openai_responses",
where={"id": response_id},
policy=self.policy,
)
if not row: if not row:
# SecureSqlStore will return None if record doesn't exist OR access is denied
# This provides security by not revealing whether the record exists
raise ValueError(f"Response with id {response_id} not found") from None raise ValueError(f"Response with id {response_id} not found") from None
return OpenAIResponseObjectWithInput(**row["response_object"]) return OpenAIResponseObjectWithInput(**row["response_object"])
async def list_response_input_items( async def list_response_input_items(

View file

@ -51,6 +51,7 @@ class SqlStore(Protocol):
self, self,
table: str, table: str,
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None, limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None, cursor: tuple[str, str] | None = None,
@ -59,7 +60,8 @@ class SqlStore(Protocol):
Fetch all rows from a table with optional cursor-based pagination. Fetch all rows from a table with optional cursor-based pagination.
:param table: The table name :param table: The table name
:param where: WHERE conditions :param where: Simple key-value WHERE conditions
:param where_sql: Raw SQL WHERE clause for complex queries
:param limit: Maximum number of records to return :param limit: Maximum number of records to return
:param order_by: List of (column, order) tuples for sorting :param order_by: List of (column, order) tuples for sorting
:param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page) :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
@ -75,6 +77,7 @@ class SqlStore(Protocol):
self, self,
table: str, table: str,
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
""" """
@ -102,3 +105,24 @@ class SqlStore(Protocol):
Delete a row from a table. Delete a row from a table.
""" """
pass pass
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""
Add a column to an existing table if the column doesn't already exist.
This is useful for table migrations when adding new functionality.
If the table doesn't exist, this method should do nothing.
If the column already exists, this method should do nothing.
:param table: Table name
:param column_name: Name of the column to add
:param column_type: Type of the column to add
:param nullable: Whether the column should be nullable (default: True)
"""
pass

View file

@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping
from typing import Any, Literal
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
from llama_stack.distribution.access_control.conditions import ProtectedResource
from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
logger = get_logger(name=__name__, category="authorized_sqlstore")
# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly
# or SQL filtering will fall back to conservative mode (safe but less performant)
#
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
# - Public records (no access_attributes) are always accessible
# - Records with access_attributes require user to match ALL categories that exist in the resource
# - Missing categories in the resource are treated as "no restriction" (allow)
# - Within each category, user needs ANY matching value (OR logic)
# - Between categories, user needs ALL categories to match (AND logic)
SQL_OPTIMIZED_POLICY = [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
),
]
class SqlRecord(ProtectedResource):
"""Simple ProtectedResource implementation for SQL records."""
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
self.type = f"sql_record::{table_name}"
self.identifier = record_id
if access_attributes:
self.owner = User(
principal="system",
attributes=access_attributes,
)
else:
self.owner = User(
principal="system_public",
attributes=None,
)
class AuthorizedSqlStore:
"""
Authorization layer for SqlStore that provides access control functionality.
This class composes a base SqlStore and adds authorization methods that handle
access control policies, user attribute capture, and SQL filtering optimization.
"""
def __init__(self, sql_store: SqlStore):
"""
Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap
"""
self.sql_store = sql_store
self._validate_sql_optimized_policy()
def _validate_sql_optimized_policy(self) -> None:
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
This ensures that if default_policy() changes, we detect the mismatch and
can update our SQL filtering logic accordingly.
"""
actual_default = default_policy()
if SQL_OPTIMIZED_POLICY != actual_default:
logger.warning(
f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
f"SQL filtering will use conservative mode. "
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
)
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support."""
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON
await self.sql_store.create_table(table, enhanced_schema)
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
"""Insert a row with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user()
if current_user and current_user.attributes:
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data)
async def fetch_all(
self,
table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(policy)
rows = await self.sql_store.fetch_all(
table=table,
where=where,
where_sql=access_where,
limit=limit,
order_by=order_by,
cursor=cursor,
)
current_user = get_authenticated_user()
filtered_rows = []
for row in rows.data:
stored_access_attrs = row.get("access_attributes")
record_id = row.get("id", "unknown")
sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
if is_action_allowed(policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
return PaginatedResponse(
data=filtered_rows,
has_more=rows.has_more,
)
async def fetch_one(
self,
table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking."""
results = await self.fetch_all(
table=table,
policy=policy,
where=where,
limit=1,
order_by=order_by,
)
return results.data[0] if results.data else None
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
"""Build SQL WHERE clause for access control filtering.
Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access.
"""
if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause()
else:
return self._build_conservative_where_clause()
def _build_default_policy_where_clause(self) -> str:
"""Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource.
"""
current_user = get_authenticated_user()
if not current_user or not current_user.attributes:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
else:
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
user_attr_conditions = []
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
if value_conditions:
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
user_matches_category = f"({' OR '.join(value_conditions)})"
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
if user_attr_conditions:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})"
else:
return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies.
Only filters records we're 100% certain would be denied by any reasonable policy.
"""
current_user = get_authenticated_user()
if not current_user:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
return "1=1"

View file

@ -17,15 +17,20 @@ from sqlalchemy import (
String, String,
Table, Table,
Text, Text,
inspect,
select, select,
text,
) )
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig from .sqlstore import SqlAlchemySqlStoreConfig
logger = get_logger(name=__name__, category="sqlstore")
TYPE_MAPPING: dict[ColumnType, Any] = { TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer, ColumnType.INTEGER: Integer,
ColumnType.STRING: String, ColumnType.STRING: String,
@ -56,7 +61,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
for col_name, col_props in schema.items(): for col_name, col_props in schema.items():
col_type = None col_type = None
is_primary_key = False is_primary_key = False
is_nullable = True # Default to nullable is_nullable = True
if isinstance(col_props, ColumnType): if isinstance(col_props, ColumnType):
col_type = col_props col_type = col_props
@ -73,14 +78,11 @@ class SqlAlchemySqlStoreImpl(SqlStore):
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable) Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
) )
# Check if table already exists in metadata, otherwise define it
if table not in self.metadata.tables: if table not in self.metadata.tables:
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns) sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
else: else:
sqlalchemy_table = self.metadata.tables[table] sqlalchemy_table = self.metadata.tables[table]
# Create the table in the database if it doesn't exist
# checkfirst=True ensures it doesn't try to recreate if it's already there
engine = create_async_engine(self.config.engine_str) engine = create_async_engine(self.config.engine_str)
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
@ -94,6 +96,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
self, self,
table: str, table: str,
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None, limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None, cursor: tuple[str, str] | None = None,
@ -106,6 +109,9 @@ class SqlAlchemySqlStoreImpl(SqlStore):
for key, value in where.items(): for key, value in where.items():
query = query.where(table_obj.c[key] == value) query = query.where(table_obj.c[key] == value)
if where_sql:
query = query.where(text(where_sql))
# Handle cursor-based pagination # Handle cursor-based pagination
if cursor: if cursor:
# Validate cursor tuple format # Validate cursor tuple format
@ -192,9 +198,10 @@ class SqlAlchemySqlStoreImpl(SqlStore):
self, self,
table: str, table: str,
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
result = await self.fetch_all(table, where, limit=1, order_by=order_by) result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
if not result.data: if not result.data:
return None return None
return result.data[0] return result.data[0]
@ -225,3 +232,47 @@ class SqlAlchemySqlStoreImpl(SqlStore):
stmt = stmt.where(self.metadata.tables[table].c[key] == value) stmt = stmt.where(self.metadata.tables[table].c[key] == value)
await session.execute(stmt) await session.execute(stmt)
await session.commit() await session.commit()
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""Add a column to an existing table if the column doesn't already exist."""
engine = create_async_engine(self.config.engine_str)
try:
inspector = inspect(engine)
table_names = inspector.get_table_names()
if table not in table_names:
return
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
if column_name in column_names:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
async with engine.begin() as conn:
await conn.execute(add_column_sql)
except Exception:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
pass

View file

@ -39,6 +39,7 @@ from llama_stack.apis.inference.inference import (
OpenAIUserMessageParam, OpenAIUserMessageParam,
) )
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.distribution.access_control.access_control import default_policy
from llama_stack.providers.inline.agents.meta_reference.openai_responses import ( from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
OpenAIResponsesImpl, OpenAIResponsesImpl,
) )
@ -599,7 +600,7 @@ async def test_responses_store_list_input_items_logic():
# Create mock store and response store # Create mock store and response store
mock_sql_store = AsyncMock() mock_sql_store = AsyncMock()
responses_store = ResponsesStore(sql_store_config=None) responses_store = ResponsesStore(sql_store_config=None, policy=default_policy())
responses_store.sql_store = mock_sql_store responses_store.sql_store = mock_sql_store
# Setup test data - multiple input items # Setup test data - multiple input items

View file

@ -47,7 +47,7 @@ async def test_inference_store_pagination_basic():
"""Test basic pagination functionality.""" """Test basic pagination functionality."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Create test data with different timestamps # Create test data with different timestamps
@ -93,7 +93,7 @@ async def test_inference_store_pagination_ascending():
"""Test pagination with ascending order.""" """Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Create test data # Create test data
@ -128,7 +128,7 @@ async def test_inference_store_pagination_with_model_filter():
"""Test pagination combined with model filtering.""" """Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Create test data with different models # Create test data with different models
@ -166,7 +166,7 @@ async def test_inference_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter.""" """Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Try to paginate with non-existent ID # Try to paginate with non-existent ID
@ -179,7 +179,7 @@ async def test_inference_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified.""" """Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Create test data # Create test data

View file

@ -49,7 +49,7 @@ async def test_responses_store_pagination_basic():
"""Test basic pagination functionality for responses store.""" """Test basic pagination functionality for responses store."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Create test data with different timestamps # Create test data with different timestamps
@ -95,7 +95,7 @@ async def test_responses_store_pagination_ascending():
"""Test pagination with ascending order.""" """Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Create test data # Create test data
@ -130,7 +130,7 @@ async def test_responses_store_pagination_with_model_filter():
"""Test pagination combined with model filtering.""" """Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Create test data with different models # Create test data with different models
@ -168,7 +168,7 @@ async def test_responses_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter.""" """Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Try to paginate with non-existent ID # Try to paginate with non-existent ID
@ -181,7 +181,7 @@ async def test_responses_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified.""" """Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Create test data # Create test data
@ -210,7 +210,7 @@ async def test_responses_store_get_response_object():
"""Test retrieving a single response object.""" """Test retrieving a single response object."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Store a test response # Store a test response
@ -235,7 +235,7 @@ async def test_responses_store_input_items_pagination():
"""Test pagination functionality for input items.""" """Test pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Store a test response with many inputs with explicit IDs # Store a test response with many inputs with explicit IDs
@ -313,7 +313,7 @@ async def test_responses_store_input_items_before_pagination():
"""Test before pagination functionality for input items.""" """Test before pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
await store.initialize() await store.initialize()
# Store a test response with many inputs with explicit IDs # Store a test response with many inputs with explicit IDs

View file

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from tempfile import TemporaryDirectory
from unittest.mock import patch
import pytest
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
from llama_stack.distribution.access_control.datatypes import Action
from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
"""Test that fetch_all works correctly with where_sql for access control"""
with TemporaryDirectory() as tmp_dir:
db_name = "test_access_control.db"
base_sqlstore = SqlAlchemySqlStoreImpl(
SqliteSqlStoreConfig(
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
# Create table with access control
await sqlstore.create_table(
table="documents",
schema={
"id": ColumnType.INTEGER,
"title": ColumnType.STRING,
"content": ColumnType.TEXT,
},
)
admin_user = User("admin-user", {"roles": ["admin"], "teams": ["engineering"]})
regular_user = User("regular-user", {"roles": ["user"], "teams": ["marketing"]})
# Set user attributes for creating documents
mock_get_authenticated_user.return_value = admin_user
# Insert documents with access attributes
await sqlstore.insert("documents", {"id": 1, "title": "Admin Document", "content": "This is admin content"})
# Change user attributes
mock_get_authenticated_user.return_value = regular_user
await sqlstore.insert("documents", {"id": 2, "title": "User Document", "content": "Public user content"})
# Test that access control works with where parameter
mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document"
# User should only see their document
mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
assert len(result.data) == 1
assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
assert row is not None
assert row["title"] == "User Document"
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_sql_policy_consistency(mock_get_authenticated_user):
"""Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
with TemporaryDirectory() as tmp_dir:
db_name = "test_consistency.db"
base_sqlstore = SqlAlchemySqlStoreImpl(
SqliteSqlStoreConfig(
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
await sqlstore.create_table(
table="resources",
schema={
"id": ColumnType.STRING,
"name": ColumnType.STRING,
},
)
# Test scenarios with different access control patterns
test_scenarios = [
# Scenario 1: Public record (no access control)
{"id": "1", "name": "public", "access_attributes": None},
# Scenario 2: Empty access control (should be treated as public)
{"id": "2", "name": "empty", "access_attributes": {}},
# Scenario 3: Record with roles requirement
{"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
# Scenario 4: Record with multiple attribute categories
{"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
# Scenario 5: Record with teams only (missing roles category)
{"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
# Scenario 6: Record with roles and projects
{
"id": "6",
"name": "admin-project-x",
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
},
]
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"]})
for scenario in test_scenarios:
await base_sqlstore.insert("resources", scenario)
# Test with different user configurations
user_scenarios = [
# User 1: No attributes (should only see public records)
{"principal": "user1", "attributes": None},
# User 2: Empty attributes (should only see public records)
{"principal": "user2", "attributes": {}},
# User 3: Admin role only
{"principal": "user3", "attributes": {"roles": ["admin"]}},
# User 4: ML team only
{"principal": "user4", "attributes": {"teams": ["ml-team"]}},
# User 5: Admin + ML team
{"principal": "user5", "attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
# User 6: Admin + Project X
{"principal": "user6", "attributes": {"roles": ["admin"], "projects": ["project-x"]}},
# User 7: Different role (should only see public)
{"principal": "user7", "attributes": {"roles": ["viewer"]}},
]
policy = default_policy()
for user_data in user_scenarios:
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy)
sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set()
for scenario in test_scenarios:
sql_record = SqlRecord(
record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"]
)
if is_action_allowed(policy, Action.READ, sql_record, user):
policy_ids.add(scenario["id"])
assert sql_ids == policy_ids, (
f"Consistency failure for user {user.principal} with attributes {user.attributes}:\n"
f"SQL returned: {sorted(sql_ids)}\n"
f"Policy allows: {sorted(policy_ids)}\n"
f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}"
)
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
"""Test that user attributes are properly captured during insert"""
with TemporaryDirectory() as tmp_dir:
db_name = "test_attributes.db"
base_sqlstore = SqlAlchemySqlStoreImpl(
SqliteSqlStoreConfig(
db_path=tmp_dir + "/" + db_name,
)
)
authorized_store = AuthorizedSqlStore(base_sqlstore)
await authorized_store.create_table(
table="user_data",
schema={
"id": ColumnType.STRING,
"content": ColumnType.STRING,
},
)
mock_get_authenticated_user.return_value = User(
"user-with-attrs", {"roles": ["editor"], "teams": ["content"], "projects": ["blog"]}
)
await authorized_store.insert("user_data", {"id": "item1", "content": "User content"})
mock_get_authenticated_user.return_value = User("user-no-attrs", None)
await authorized_store.insert("user_data", {"id": "item2", "content": "Public content"})
mock_get_authenticated_user.return_value = None
await authorized_store.insert("user_data", {"id": "item3", "content": "Anonymous content"})
result = await base_sqlstore.fetch_all("user_data", order_by=[("id", "asc")])
assert len(result.data) == 3
# First item should have full attributes
assert result.data[0]["id"] == "item1"
assert result.data[0]["access_attributes"] == {"roles": ["editor"], "teams": ["content"], "projects": ["blog"]}
# Second item should have null attributes (user with no attributes)
assert result.data[1]["id"] == "item2"
assert result.data[1]["access_attributes"] is None
# Third item should have null attributes (no authenticated user)
assert result.data[2]["id"] == "item3"
assert result.data[2]["access_attributes"] is None