mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: support auth attributes in inference/responses stores (#2389)
# What does this PR do? Inference/Response stores now store user attributes when inserting, and respects them when fetching. ## Test Plan pytest tests/unit/utils/test_sqlstore.py
This commit is contained in:
parent
7930c524f9
commit
d3b60507d7
12 changed files with 575 additions and 32 deletions
|
@ -335,7 +335,7 @@ async def instantiate_provider(
|
||||||
method = "get_auto_router_impl"
|
method = "get_auto_router_impl"
|
||||||
|
|
||||||
config = None
|
config = None
|
||||||
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
|
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config, policy]
|
||||||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||||
method = "get_routing_table_impl"
|
method = "get_routing_table_impl"
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,7 @@ async def get_routing_table_impl(
|
||||||
|
|
||||||
|
|
||||||
async def get_auto_router_impl(
|
async def get_auto_router_impl(
|
||||||
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
|
api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig, policy: list[AccessRule]
|
||||||
) -> Any:
|
) -> Any:
|
||||||
from .datasets import DatasetIORouter
|
from .datasets import DatasetIORouter
|
||||||
from .eval_scoring import EvalRouter, ScoringRouter
|
from .eval_scoring import EvalRouter, ScoringRouter
|
||||||
|
@ -78,7 +78,7 @@ async def get_auto_router_impl(
|
||||||
|
|
||||||
# TODO: move pass configs to routers instead
|
# TODO: move pass configs to routers instead
|
||||||
if api == Api.inference and run_config.inference_store:
|
if api == Api.inference and run_config.inference_store:
|
||||||
inference_store = InferenceStore(run_config.inference_store)
|
inference_store = InferenceStore(run_config.inference_store, policy)
|
||||||
await inference_store.initialize()
|
await inference_store.initialize()
|
||||||
api_to_dep_impl["store"] = inference_store
|
api_to_dep_impl["store"] = inference_store
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||||
self.responses_store = ResponsesStore(self.config.responses_store)
|
self.responses_store = ResponsesStore(self.config.responses_store, self.policy)
|
||||||
await self.responses_store.initialize()
|
await self.responses_store.initialize()
|
||||||
self.openai_responses_impl = OpenAIResponsesImpl(
|
self.openai_responses_impl = OpenAIResponsesImpl(
|
||||||
inference_api=self.inference_api,
|
inference_api=self.inference_api,
|
||||||
|
|
|
@ -10,24 +10,27 @@ from llama_stack.apis.inference import (
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
Order,
|
Order,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.datatypes import AccessRule
|
||||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||||
|
|
||||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||||
|
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||||
|
|
||||||
|
|
||||||
class InferenceStore:
|
class InferenceStore:
|
||||||
def __init__(self, sql_store_config: SqlStoreConfig):
|
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
|
||||||
if not sql_store_config:
|
if not sql_store_config:
|
||||||
sql_store_config = SqliteSqlStoreConfig(
|
sql_store_config = SqliteSqlStoreConfig(
|
||||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||||
)
|
)
|
||||||
self.sql_store_config = sql_store_config
|
self.sql_store_config = sql_store_config
|
||||||
self.sql_store = None
|
self.sql_store = None
|
||||||
|
self.policy = policy
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Create the necessary tables if they don't exist."""
|
"""Create the necessary tables if they don't exist."""
|
||||||
self.sql_store = sqlstore_impl(self.sql_store_config)
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
|
||||||
await self.sql_store.create_table(
|
await self.sql_store.create_table(
|
||||||
"chat_completions",
|
"chat_completions",
|
||||||
{
|
{
|
||||||
|
@ -48,8 +51,8 @@ class InferenceStore:
|
||||||
data = chat_completion.model_dump()
|
data = chat_completion.model_dump()
|
||||||
|
|
||||||
await self.sql_store.insert(
|
await self.sql_store.insert(
|
||||||
"chat_completions",
|
table="chat_completions",
|
||||||
{
|
data={
|
||||||
"id": data["id"],
|
"id": data["id"],
|
||||||
"created": data["created"],
|
"created": data["created"],
|
||||||
"model": data["model"],
|
"model": data["model"],
|
||||||
|
@ -89,6 +92,7 @@ class InferenceStore:
|
||||||
order_by=[("created", order.value)],
|
order_by=[("created", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
policy=self.policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [
|
data = [
|
||||||
|
@ -112,9 +116,17 @@ class InferenceStore:
|
||||||
if not self.sql_store:
|
if not self.sql_store:
|
||||||
raise ValueError("Inference store is not initialized")
|
raise ValueError("Inference store is not initialized")
|
||||||
|
|
||||||
row = await self.sql_store.fetch_one("chat_completions", where={"id": completion_id})
|
row = await self.sql_store.fetch_one(
|
||||||
|
table="chat_completions",
|
||||||
|
where={"id": completion_id},
|
||||||
|
policy=self.policy,
|
||||||
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
|
# SecureSqlStore will return None if record doesn't exist OR access is denied
|
||||||
|
# This provides security by not revealing whether the record exists
|
||||||
raise ValueError(f"Chat completion with id {completion_id} not found") from None
|
raise ValueError(f"Chat completion with id {completion_id} not found") from None
|
||||||
|
|
||||||
return OpenAICompletionWithInputMessages(
|
return OpenAICompletionWithInputMessages(
|
||||||
id=row["id"],
|
id=row["id"],
|
||||||
created=row["created"],
|
created=row["created"],
|
||||||
|
|
|
@ -13,19 +13,22 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
OpenAIResponseObjectWithInput,
|
OpenAIResponseObjectWithInput,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.datatypes import AccessRule
|
||||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||||
|
|
||||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||||
|
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||||
|
|
||||||
|
|
||||||
class ResponsesStore:
|
class ResponsesStore:
|
||||||
def __init__(self, sql_store_config: SqlStoreConfig):
|
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
|
||||||
if not sql_store_config:
|
if not sql_store_config:
|
||||||
sql_store_config = SqliteSqlStoreConfig(
|
sql_store_config = SqliteSqlStoreConfig(
|
||||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||||
)
|
)
|
||||||
self.sql_store = sqlstore_impl(sql_store_config)
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
|
||||||
|
self.policy = policy
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Create the necessary tables if they don't exist."""
|
"""Create the necessary tables if they don't exist."""
|
||||||
|
@ -83,6 +86,7 @@ class ResponsesStore:
|
||||||
order_by=[("created_at", order.value)],
|
order_by=[("created_at", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
policy=self.policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
||||||
|
@ -94,9 +98,20 @@ class ResponsesStore:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
|
async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput:
|
||||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
"""
|
||||||
|
Get a response object with automatic access control checking.
|
||||||
|
"""
|
||||||
|
row = await self.sql_store.fetch_one(
|
||||||
|
"openai_responses",
|
||||||
|
where={"id": response_id},
|
||||||
|
policy=self.policy,
|
||||||
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
|
# SecureSqlStore will return None if record doesn't exist OR access is denied
|
||||||
|
# This provides security by not revealing whether the record exists
|
||||||
raise ValueError(f"Response with id {response_id} not found") from None
|
raise ValueError(f"Response with id {response_id} not found") from None
|
||||||
|
|
||||||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||||
|
|
||||||
async def list_response_input_items(
|
async def list_response_input_items(
|
||||||
|
|
|
@ -51,6 +51,7 @@ class SqlStore(Protocol):
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
where: Mapping[str, Any] | None = None,
|
where: Mapping[str, Any] | None = None,
|
||||||
|
where_sql: str | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
cursor: tuple[str, str] | None = None,
|
cursor: tuple[str, str] | None = None,
|
||||||
|
@ -59,7 +60,8 @@ class SqlStore(Protocol):
|
||||||
Fetch all rows from a table with optional cursor-based pagination.
|
Fetch all rows from a table with optional cursor-based pagination.
|
||||||
|
|
||||||
:param table: The table name
|
:param table: The table name
|
||||||
:param where: WHERE conditions
|
:param where: Simple key-value WHERE conditions
|
||||||
|
:param where_sql: Raw SQL WHERE clause for complex queries
|
||||||
:param limit: Maximum number of records to return
|
:param limit: Maximum number of records to return
|
||||||
:param order_by: List of (column, order) tuples for sorting
|
:param order_by: List of (column, order) tuples for sorting
|
||||||
:param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
|
:param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
|
||||||
|
@ -75,6 +77,7 @@ class SqlStore(Protocol):
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
where: Mapping[str, Any] | None = None,
|
where: Mapping[str, Any] | None = None,
|
||||||
|
where_sql: str | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
|
@ -102,3 +105,24 @@ class SqlStore(Protocol):
|
||||||
Delete a row from a table.
|
Delete a row from a table.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def add_column_if_not_exists(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
column_name: str,
|
||||||
|
column_type: ColumnType,
|
||||||
|
nullable: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Add a column to an existing table if the column doesn't already exist.
|
||||||
|
|
||||||
|
This is useful for table migrations when adding new functionality.
|
||||||
|
If the table doesn't exist, this method should do nothing.
|
||||||
|
If the column already exists, this method should do nothing.
|
||||||
|
|
||||||
|
:param table: Table name
|
||||||
|
:param column_name: Name of the column to add
|
||||||
|
:param column_type: Type of the column to add
|
||||||
|
:param nullable: Whether the column should be nullable (default: True)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
222
llama_stack/providers/utils/sqlstore/authorized_sqlstore.py
Normal file
222
llama_stack/providers/utils/sqlstore/authorized_sqlstore.py
Normal file
|
@ -0,0 +1,222 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
|
||||||
|
from llama_stack.distribution.access_control.conditions import ProtectedResource
|
||||||
|
from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
from llama_stack.distribution.request_headers import get_authenticated_user
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
||||||
|
|
||||||
|
# Hardcoded copy of the default policy that our SQL filtering implements
|
||||||
|
# WARNING: If default_policy() changes, this constant must be updated accordingly
|
||||||
|
# or SQL filtering will fall back to conservative mode (safe but less performant)
|
||||||
|
#
|
||||||
|
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
|
||||||
|
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
|
||||||
|
# - Public records (no access_attributes) are always accessible
|
||||||
|
# - Records with access_attributes require user to match ALL categories that exist in the resource
|
||||||
|
# - Missing categories in the resource are treated as "no restriction" (allow)
|
||||||
|
# - Within each category, user needs ANY matching value (OR logic)
|
||||||
|
# - Between categories, user needs ALL categories to match (AND logic)
|
||||||
|
SQL_OPTIMIZED_POLICY = [
|
||||||
|
AccessRule(
|
||||||
|
permit=Scope(actions=list(Action)),
|
||||||
|
when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SqlRecord(ProtectedResource):
|
||||||
|
"""Simple ProtectedResource implementation for SQL records."""
|
||||||
|
|
||||||
|
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
|
||||||
|
self.type = f"sql_record::{table_name}"
|
||||||
|
self.identifier = record_id
|
||||||
|
|
||||||
|
if access_attributes:
|
||||||
|
self.owner = User(
|
||||||
|
principal="system",
|
||||||
|
attributes=access_attributes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.owner = User(
|
||||||
|
principal="system_public",
|
||||||
|
attributes=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizedSqlStore:
|
||||||
|
"""
|
||||||
|
Authorization layer for SqlStore that provides access control functionality.
|
||||||
|
|
||||||
|
This class composes a base SqlStore and adds authorization methods that handle
|
||||||
|
access control policies, user attribute capture, and SQL filtering optimization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sql_store: SqlStore):
|
||||||
|
"""
|
||||||
|
Initialize the authorization layer.
|
||||||
|
|
||||||
|
:param sql_store: Base SqlStore implementation to wrap
|
||||||
|
"""
|
||||||
|
self.sql_store = sql_store
|
||||||
|
|
||||||
|
self._validate_sql_optimized_policy()
|
||||||
|
|
||||||
|
def _validate_sql_optimized_policy(self) -> None:
|
||||||
|
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
||||||
|
|
||||||
|
This ensures that if default_policy() changes, we detect the mismatch and
|
||||||
|
can update our SQL filtering logic accordingly.
|
||||||
|
"""
|
||||||
|
actual_default = default_policy()
|
||||||
|
|
||||||
|
if SQL_OPTIMIZED_POLICY != actual_default:
|
||||||
|
logger.warning(
|
||||||
|
f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
|
||||||
|
f"SQL filtering will use conservative mode. "
|
||||||
|
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
|
||||||
|
"""Create a table with built-in access control support."""
|
||||||
|
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
||||||
|
|
||||||
|
enhanced_schema = dict(schema)
|
||||||
|
if "access_attributes" not in enhanced_schema:
|
||||||
|
enhanced_schema["access_attributes"] = ColumnType.JSON
|
||||||
|
|
||||||
|
await self.sql_store.create_table(table, enhanced_schema)
|
||||||
|
|
||||||
|
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||||
|
"""Insert a row with automatic access control attribute capture."""
|
||||||
|
enhanced_data = dict(data)
|
||||||
|
|
||||||
|
current_user = get_authenticated_user()
|
||||||
|
if current_user and current_user.attributes:
|
||||||
|
enhanced_data["access_attributes"] = current_user.attributes
|
||||||
|
else:
|
||||||
|
enhanced_data["access_attributes"] = None
|
||||||
|
|
||||||
|
await self.sql_store.insert(table, enhanced_data)
|
||||||
|
|
||||||
|
async def fetch_all(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
policy: list[AccessRule],
|
||||||
|
where: Mapping[str, Any] | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
|
cursor: tuple[str, str] | None = None,
|
||||||
|
) -> PaginatedResponse:
|
||||||
|
"""Fetch all rows with automatic access control filtering."""
|
||||||
|
access_where = self._build_access_control_where_clause(policy)
|
||||||
|
rows = await self.sql_store.fetch_all(
|
||||||
|
table=table,
|
||||||
|
where=where,
|
||||||
|
where_sql=access_where,
|
||||||
|
limit=limit,
|
||||||
|
order_by=order_by,
|
||||||
|
cursor=cursor,
|
||||||
|
)
|
||||||
|
|
||||||
|
current_user = get_authenticated_user()
|
||||||
|
filtered_rows = []
|
||||||
|
|
||||||
|
for row in rows.data:
|
||||||
|
stored_access_attrs = row.get("access_attributes")
|
||||||
|
|
||||||
|
record_id = row.get("id", "unknown")
|
||||||
|
sql_record = SqlRecord(str(record_id), table, stored_access_attrs)
|
||||||
|
|
||||||
|
if is_action_allowed(policy, Action.READ, sql_record, current_user):
|
||||||
|
filtered_rows.append(row)
|
||||||
|
|
||||||
|
return PaginatedResponse(
|
||||||
|
data=filtered_rows,
|
||||||
|
has_more=rows.has_more,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fetch_one(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
policy: list[AccessRule],
|
||||||
|
where: Mapping[str, Any] | None = None,
|
||||||
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Fetch one row with automatic access control checking."""
|
||||||
|
results = await self.fetch_all(
|
||||||
|
table=table,
|
||||||
|
policy=policy,
|
||||||
|
where=where,
|
||||||
|
limit=1,
|
||||||
|
order_by=order_by,
|
||||||
|
)
|
||||||
|
|
||||||
|
return results.data[0] if results.data else None
|
||||||
|
|
||||||
|
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
|
||||||
|
"""Build SQL WHERE clause for access control filtering.
|
||||||
|
|
||||||
|
Only applies SQL filtering for the default policy to ensure correctness.
|
||||||
|
For custom policies, uses conservative filtering to avoid blocking legitimate access.
|
||||||
|
"""
|
||||||
|
if not policy or policy == SQL_OPTIMIZED_POLICY:
|
||||||
|
return self._build_default_policy_where_clause()
|
||||||
|
else:
|
||||||
|
return self._build_conservative_where_clause()
|
||||||
|
|
||||||
|
def _build_default_policy_where_clause(self) -> str:
|
||||||
|
"""Build SQL WHERE clause for the default policy.
|
||||||
|
|
||||||
|
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
|
||||||
|
This means user must match ALL attribute categories that exist in the resource.
|
||||||
|
"""
|
||||||
|
current_user = get_authenticated_user()
|
||||||
|
|
||||||
|
if not current_user or not current_user.attributes:
|
||||||
|
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
|
||||||
|
else:
|
||||||
|
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
|
||||||
|
|
||||||
|
user_attr_conditions = []
|
||||||
|
|
||||||
|
for attr_key, user_values in current_user.attributes.items():
|
||||||
|
if user_values:
|
||||||
|
value_conditions = []
|
||||||
|
for value in user_values:
|
||||||
|
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
|
||||||
|
|
||||||
|
if value_conditions:
|
||||||
|
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
|
||||||
|
user_matches_category = f"({' OR '.join(value_conditions)})"
|
||||||
|
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
|
||||||
|
|
||||||
|
if user_attr_conditions:
|
||||||
|
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
|
||||||
|
base_conditions.append(all_requirements_met)
|
||||||
|
return f"({' OR '.join(base_conditions)})"
|
||||||
|
else:
|
||||||
|
return f"({' OR '.join(base_conditions)})"
|
||||||
|
|
||||||
|
def _build_conservative_where_clause(self) -> str:
|
||||||
|
"""Conservative SQL filtering for custom policies.
|
||||||
|
|
||||||
|
Only filters records we're 100% certain would be denied by any reasonable policy.
|
||||||
|
"""
|
||||||
|
current_user = get_authenticated_user()
|
||||||
|
|
||||||
|
if not current_user:
|
||||||
|
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
|
||||||
|
return "1=1"
|
|
@ -17,15 +17,20 @@ from sqlalchemy import (
|
||||||
String,
|
String,
|
||||||
Table,
|
Table,
|
||||||
Text,
|
Text,
|
||||||
|
inspect,
|
||||||
select,
|
select,
|
||||||
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .api import ColumnDefinition, ColumnType, SqlStore
|
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="sqlstore")
|
||||||
|
|
||||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||||
ColumnType.INTEGER: Integer,
|
ColumnType.INTEGER: Integer,
|
||||||
ColumnType.STRING: String,
|
ColumnType.STRING: String,
|
||||||
|
@ -56,7 +61,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
for col_name, col_props in schema.items():
|
for col_name, col_props in schema.items():
|
||||||
col_type = None
|
col_type = None
|
||||||
is_primary_key = False
|
is_primary_key = False
|
||||||
is_nullable = True # Default to nullable
|
is_nullable = True
|
||||||
|
|
||||||
if isinstance(col_props, ColumnType):
|
if isinstance(col_props, ColumnType):
|
||||||
col_type = col_props
|
col_type = col_props
|
||||||
|
@ -73,14 +78,11 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
|
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if table already exists in metadata, otherwise define it
|
|
||||||
if table not in self.metadata.tables:
|
if table not in self.metadata.tables:
|
||||||
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
|
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
|
||||||
else:
|
else:
|
||||||
sqlalchemy_table = self.metadata.tables[table]
|
sqlalchemy_table = self.metadata.tables[table]
|
||||||
|
|
||||||
# Create the table in the database if it doesn't exist
|
|
||||||
# checkfirst=True ensures it doesn't try to recreate if it's already there
|
|
||||||
engine = create_async_engine(self.config.engine_str)
|
engine = create_async_engine(self.config.engine_str)
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||||
|
@ -94,6 +96,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
where: Mapping[str, Any] | None = None,
|
where: Mapping[str, Any] | None = None,
|
||||||
|
where_sql: str | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
cursor: tuple[str, str] | None = None,
|
cursor: tuple[str, str] | None = None,
|
||||||
|
@ -106,6 +109,9 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
for key, value in where.items():
|
for key, value in where.items():
|
||||||
query = query.where(table_obj.c[key] == value)
|
query = query.where(table_obj.c[key] == value)
|
||||||
|
|
||||||
|
if where_sql:
|
||||||
|
query = query.where(text(where_sql))
|
||||||
|
|
||||||
# Handle cursor-based pagination
|
# Handle cursor-based pagination
|
||||||
if cursor:
|
if cursor:
|
||||||
# Validate cursor tuple format
|
# Validate cursor tuple format
|
||||||
|
@ -192,9 +198,10 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
where: Mapping[str, Any] | None = None,
|
where: Mapping[str, Any] | None = None,
|
||||||
|
where_sql: str | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
result = await self.fetch_all(table, where, limit=1, order_by=order_by)
|
result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
return None
|
return None
|
||||||
return result.data[0]
|
return result.data[0]
|
||||||
|
@ -225,3 +232,47 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
async def add_column_if_not_exists(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
column_name: str,
|
||||||
|
column_type: ColumnType,
|
||||||
|
nullable: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Add a column to an existing table if the column doesn't already exist."""
|
||||||
|
engine = create_async_engine(self.config.engine_str)
|
||||||
|
|
||||||
|
try:
|
||||||
|
inspector = inspect(engine)
|
||||||
|
|
||||||
|
table_names = inspector.get_table_names()
|
||||||
|
if table not in table_names:
|
||||||
|
return
|
||||||
|
|
||||||
|
existing_columns = inspector.get_columns(table)
|
||||||
|
column_names = [col["name"] for col in existing_columns]
|
||||||
|
|
||||||
|
if column_name in column_names:
|
||||||
|
return
|
||||||
|
|
||||||
|
sqlalchemy_type = TYPE_MAPPING.get(column_type)
|
||||||
|
if not sqlalchemy_type:
|
||||||
|
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
|
||||||
|
|
||||||
|
# Create the ALTER TABLE statement
|
||||||
|
# Note: We need to get the dialect-specific type name
|
||||||
|
dialect = engine.dialect
|
||||||
|
type_impl = sqlalchemy_type()
|
||||||
|
compiled_type = type_impl.compile(dialect=dialect)
|
||||||
|
|
||||||
|
nullable_clause = "" if nullable else " NOT NULL"
|
||||||
|
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
|
||||||
|
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.execute(add_column_sql)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# If any error occurs during migration, log it but don't fail
|
||||||
|
# The table creation will handle adding the column
|
||||||
|
pass
|
||||||
|
|
|
@ -39,6 +39,7 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIUserMessageParam,
|
OpenAIUserMessageParam,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||||
|
from llama_stack.distribution.access_control.access_control import default_policy
|
||||||
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
||||||
OpenAIResponsesImpl,
|
OpenAIResponsesImpl,
|
||||||
)
|
)
|
||||||
|
@ -599,7 +600,7 @@ async def test_responses_store_list_input_items_logic():
|
||||||
|
|
||||||
# Create mock store and response store
|
# Create mock store and response store
|
||||||
mock_sql_store = AsyncMock()
|
mock_sql_store = AsyncMock()
|
||||||
responses_store = ResponsesStore(sql_store_config=None)
|
responses_store = ResponsesStore(sql_store_config=None, policy=default_policy())
|
||||||
responses_store.sql_store = mock_sql_store
|
responses_store.sql_store = mock_sql_store
|
||||||
|
|
||||||
# Setup test data - multiple input items
|
# Setup test data - multiple input items
|
||||||
|
|
|
@ -47,7 +47,7 @@ async def test_inference_store_pagination_basic():
|
||||||
"""Test basic pagination functionality."""
|
"""Test basic pagination functionality."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Create test data with different timestamps
|
# Create test data with different timestamps
|
||||||
|
@ -93,7 +93,7 @@ async def test_inference_store_pagination_ascending():
|
||||||
"""Test pagination with ascending order."""
|
"""Test pagination with ascending order."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Create test data
|
# Create test data
|
||||||
|
@ -128,7 +128,7 @@ async def test_inference_store_pagination_with_model_filter():
|
||||||
"""Test pagination combined with model filtering."""
|
"""Test pagination combined with model filtering."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Create test data with different models
|
# Create test data with different models
|
||||||
|
@ -166,7 +166,7 @@ async def test_inference_store_pagination_invalid_after():
|
||||||
"""Test error handling for invalid 'after' parameter."""
|
"""Test error handling for invalid 'after' parameter."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Try to paginate with non-existent ID
|
# Try to paginate with non-existent ID
|
||||||
|
@ -179,7 +179,7 @@ async def test_inference_store_pagination_no_limit():
|
||||||
"""Test pagination behavior when no limit is specified."""
|
"""Test pagination behavior when no limit is specified."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Create test data
|
# Create test data
|
||||||
|
|
|
@ -49,7 +49,7 @@ async def test_responses_store_pagination_basic():
|
||||||
"""Test basic pagination functionality for responses store."""
|
"""Test basic pagination functionality for responses store."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Create test data with different timestamps
|
# Create test data with different timestamps
|
||||||
|
@ -95,7 +95,7 @@ async def test_responses_store_pagination_ascending():
|
||||||
"""Test pagination with ascending order."""
|
"""Test pagination with ascending order."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Create test data
|
# Create test data
|
||||||
|
@ -130,7 +130,7 @@ async def test_responses_store_pagination_with_model_filter():
|
||||||
"""Test pagination combined with model filtering."""
|
"""Test pagination combined with model filtering."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Create test data with different models
|
# Create test data with different models
|
||||||
|
@ -168,7 +168,7 @@ async def test_responses_store_pagination_invalid_after():
|
||||||
"""Test error handling for invalid 'after' parameter."""
|
"""Test error handling for invalid 'after' parameter."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Try to paginate with non-existent ID
|
# Try to paginate with non-existent ID
|
||||||
|
@ -181,7 +181,7 @@ async def test_responses_store_pagination_no_limit():
|
||||||
"""Test pagination behavior when no limit is specified."""
|
"""Test pagination behavior when no limit is specified."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Create test data
|
# Create test data
|
||||||
|
@ -210,7 +210,7 @@ async def test_responses_store_get_response_object():
|
||||||
"""Test retrieving a single response object."""
|
"""Test retrieving a single response object."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Store a test response
|
# Store a test response
|
||||||
|
@ -235,7 +235,7 @@ async def test_responses_store_input_items_pagination():
|
||||||
"""Test pagination functionality for input items."""
|
"""Test pagination functionality for input items."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Store a test response with many inputs with explicit IDs
|
# Store a test response with many inputs with explicit IDs
|
||||||
|
@ -313,7 +313,7 @@ async def test_responses_store_input_items_before_pagination():
|
||||||
"""Test before pagination functionality for input items."""
|
"""Test before pagination functionality for input items."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||||
await store.initialize()
|
await store.initialize()
|
||||||
|
|
||||||
# Store a test response with many inputs with explicit IDs
|
# Store a test response with many inputs with explicit IDs
|
||||||
|
|
218
tests/unit/utils/test_authorized_sqlstore.py
Normal file
218
tests/unit/utils/test_authorized_sqlstore.py
Normal file
|
@ -0,0 +1,218 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
|
||||||
|
from llama_stack.distribution.access_control.datatypes import Action
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||||
|
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
|
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
|
||||||
|
"""Test that fetch_all works correctly with where_sql for access control"""
|
||||||
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
db_name = "test_access_control.db"
|
||||||
|
base_sqlstore = SqlAlchemySqlStoreImpl(
|
||||||
|
SqliteSqlStoreConfig(
|
||||||
|
db_path=tmp_dir + "/" + db_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
||||||
|
|
||||||
|
# Create table with access control
|
||||||
|
await sqlstore.create_table(
|
||||||
|
table="documents",
|
||||||
|
schema={
|
||||||
|
"id": ColumnType.INTEGER,
|
||||||
|
"title": ColumnType.STRING,
|
||||||
|
"content": ColumnType.TEXT,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
admin_user = User("admin-user", {"roles": ["admin"], "teams": ["engineering"]})
|
||||||
|
regular_user = User("regular-user", {"roles": ["user"], "teams": ["marketing"]})
|
||||||
|
|
||||||
|
# Set user attributes for creating documents
|
||||||
|
mock_get_authenticated_user.return_value = admin_user
|
||||||
|
|
||||||
|
# Insert documents with access attributes
|
||||||
|
await sqlstore.insert("documents", {"id": 1, "title": "Admin Document", "content": "This is admin content"})
|
||||||
|
|
||||||
|
# Change user attributes
|
||||||
|
mock_get_authenticated_user.return_value = regular_user
|
||||||
|
|
||||||
|
await sqlstore.insert("documents", {"id": 2, "title": "User Document", "content": "Public user content"})
|
||||||
|
|
||||||
|
# Test that access control works with where parameter
|
||||||
|
mock_get_authenticated_user.return_value = admin_user
|
||||||
|
|
||||||
|
# Admin should see both documents
|
||||||
|
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0]["title"] == "Admin Document"
|
||||||
|
|
||||||
|
# User should only see their document
|
||||||
|
mock_get_authenticated_user.return_value = regular_user
|
||||||
|
|
||||||
|
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
||||||
|
assert len(result.data) == 0
|
||||||
|
|
||||||
|
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0]["title"] == "User Document"
|
||||||
|
|
||||||
|
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
|
||||||
|
assert row is None
|
||||||
|
|
||||||
|
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
|
||||||
|
assert row is not None
|
||||||
|
assert row["title"] == "User Document"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
|
async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
|
"""Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
|
||||||
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
db_name = "test_consistency.db"
|
||||||
|
base_sqlstore = SqlAlchemySqlStoreImpl(
|
||||||
|
SqliteSqlStoreConfig(
|
||||||
|
db_path=tmp_dir + "/" + db_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
||||||
|
|
||||||
|
await sqlstore.create_table(
|
||||||
|
table="resources",
|
||||||
|
schema={
|
||||||
|
"id": ColumnType.STRING,
|
||||||
|
"name": ColumnType.STRING,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test scenarios with different access control patterns
|
||||||
|
test_scenarios = [
|
||||||
|
# Scenario 1: Public record (no access control)
|
||||||
|
{"id": "1", "name": "public", "access_attributes": None},
|
||||||
|
# Scenario 2: Empty access control (should be treated as public)
|
||||||
|
{"id": "2", "name": "empty", "access_attributes": {}},
|
||||||
|
# Scenario 3: Record with roles requirement
|
||||||
|
{"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
|
||||||
|
# Scenario 4: Record with multiple attribute categories
|
||||||
|
{"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
||||||
|
# Scenario 5: Record with teams only (missing roles category)
|
||||||
|
{"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
|
||||||
|
# Scenario 6: Record with roles and projects
|
||||||
|
{
|
||||||
|
"id": "6",
|
||||||
|
"name": "admin-project-x",
|
||||||
|
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"]})
|
||||||
|
for scenario in test_scenarios:
|
||||||
|
await base_sqlstore.insert("resources", scenario)
|
||||||
|
|
||||||
|
# Test with different user configurations
|
||||||
|
user_scenarios = [
|
||||||
|
# User 1: No attributes (should only see public records)
|
||||||
|
{"principal": "user1", "attributes": None},
|
||||||
|
# User 2: Empty attributes (should only see public records)
|
||||||
|
{"principal": "user2", "attributes": {}},
|
||||||
|
# User 3: Admin role only
|
||||||
|
{"principal": "user3", "attributes": {"roles": ["admin"]}},
|
||||||
|
# User 4: ML team only
|
||||||
|
{"principal": "user4", "attributes": {"teams": ["ml-team"]}},
|
||||||
|
# User 5: Admin + ML team
|
||||||
|
{"principal": "user5", "attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
||||||
|
# User 6: Admin + Project X
|
||||||
|
{"principal": "user6", "attributes": {"roles": ["admin"], "projects": ["project-x"]}},
|
||||||
|
# User 7: Different role (should only see public)
|
||||||
|
{"principal": "user7", "attributes": {"roles": ["viewer"]}},
|
||||||
|
]
|
||||||
|
|
||||||
|
policy = default_policy()
|
||||||
|
|
||||||
|
for user_data in user_scenarios:
|
||||||
|
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
|
||||||
|
mock_get_authenticated_user.return_value = user
|
||||||
|
|
||||||
|
sql_results = await sqlstore.fetch_all("resources", policy=policy)
|
||||||
|
sql_ids = {row["id"] for row in sql_results.data}
|
||||||
|
policy_ids = set()
|
||||||
|
for scenario in test_scenarios:
|
||||||
|
sql_record = SqlRecord(
|
||||||
|
record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_action_allowed(policy, Action.READ, sql_record, user):
|
||||||
|
policy_ids.add(scenario["id"])
|
||||||
|
assert sql_ids == policy_ids, (
|
||||||
|
f"Consistency failure for user {user.principal} with attributes {user.attributes}:\n"
|
||||||
|
f"SQL returned: {sorted(sql_ids)}\n"
|
||||||
|
f"Policy allows: {sorted(policy_ids)}\n"
|
||||||
|
f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
|
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
|
||||||
|
"""Test that user attributes are properly captured during insert"""
|
||||||
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
db_name = "test_attributes.db"
|
||||||
|
base_sqlstore = SqlAlchemySqlStoreImpl(
|
||||||
|
SqliteSqlStoreConfig(
|
||||||
|
db_path=tmp_dir + "/" + db_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||||
|
|
||||||
|
await authorized_store.create_table(
|
||||||
|
table="user_data",
|
||||||
|
schema={
|
||||||
|
"id": ColumnType.STRING,
|
||||||
|
"content": ColumnType.STRING,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_authenticated_user.return_value = User(
|
||||||
|
"user-with-attrs", {"roles": ["editor"], "teams": ["content"], "projects": ["blog"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
await authorized_store.insert("user_data", {"id": "item1", "content": "User content"})
|
||||||
|
|
||||||
|
mock_get_authenticated_user.return_value = User("user-no-attrs", None)
|
||||||
|
|
||||||
|
await authorized_store.insert("user_data", {"id": "item2", "content": "Public content"})
|
||||||
|
|
||||||
|
mock_get_authenticated_user.return_value = None
|
||||||
|
|
||||||
|
await authorized_store.insert("user_data", {"id": "item3", "content": "Anonymous content"})
|
||||||
|
result = await base_sqlstore.fetch_all("user_data", order_by=[("id", "asc")])
|
||||||
|
assert len(result.data) == 3
|
||||||
|
|
||||||
|
# First item should have full attributes
|
||||||
|
assert result.data[0]["id"] == "item1"
|
||||||
|
assert result.data[0]["access_attributes"] == {"roles": ["editor"], "teams": ["content"], "projects": ["blog"]}
|
||||||
|
|
||||||
|
# Second item should have null attributes (user with no attributes)
|
||||||
|
assert result.data[1]["id"] == "item2"
|
||||||
|
assert result.data[1]["access_attributes"] is None
|
||||||
|
|
||||||
|
# Third item should have null attributes (no authenticated user)
|
||||||
|
assert result.data[2]["id"] == "item3"
|
||||||
|
assert result.data[2]["access_attributes"] is None
|
Loading…
Add table
Add a link
Reference in a new issue