mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: support auth attributes in inference/responses stores (#2389)
# What does this PR do? Inference/Response stores now store user attributes when inserting, and respects them when fetching. ## Test Plan pytest tests/unit/utils/test_sqlstore.py
This commit is contained in:
parent
7930c524f9
commit
d3b60507d7
12 changed files with 575 additions and 32 deletions
|
@ -39,6 +39,7 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||
from llama_stack.distribution.access_control.access_control import default_policy
|
||||
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
@ -599,7 +600,7 @@ async def test_responses_store_list_input_items_logic():
|
|||
|
||||
# Create mock store and response store
|
||||
mock_sql_store = AsyncMock()
|
||||
responses_store = ResponsesStore(sql_store_config=None)
|
||||
responses_store = ResponsesStore(sql_store_config=None, policy=default_policy())
|
||||
responses_store.sql_store = mock_sql_store
|
||||
|
||||
# Setup test data - multiple input items
|
||||
|
|
|
@ -47,7 +47,7 @@ async def test_inference_store_pagination_basic():
|
|||
"""Test basic pagination functionality."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
|
@ -93,7 +93,7 @@ async def test_inference_store_pagination_ascending():
|
|||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
@ -128,7 +128,7 @@ async def test_inference_store_pagination_with_model_filter():
|
|||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
|
@ -166,7 +166,7 @@ async def test_inference_store_pagination_invalid_after():
|
|||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
|
@ -179,7 +179,7 @@ async def test_inference_store_pagination_no_limit():
|
|||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
|
|
@ -49,7 +49,7 @@ async def test_responses_store_pagination_basic():
|
|||
"""Test basic pagination functionality for responses store."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
|
@ -95,7 +95,7 @@ async def test_responses_store_pagination_ascending():
|
|||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
@ -130,7 +130,7 @@ async def test_responses_store_pagination_with_model_filter():
|
|||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
|
@ -168,7 +168,7 @@ async def test_responses_store_pagination_invalid_after():
|
|||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
|
@ -181,7 +181,7 @@ async def test_responses_store_pagination_no_limit():
|
|||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
|
@ -210,7 +210,7 @@ async def test_responses_store_get_response_object():
|
|||
"""Test retrieving a single response object."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response
|
||||
|
@ -235,7 +235,7 @@ async def test_responses_store_input_items_pagination():
|
|||
"""Test pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
@ -313,7 +313,7 @@ async def test_responses_store_input_items_before_pagination():
|
|||
"""Test before pagination functionality for input items."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
|
||||
await store.initialize()
|
||||
|
||||
# Store a test response with many inputs with explicit IDs
|
||||
|
|
218
tests/unit/utils/test_authorized_sqlstore.py
Normal file
218
tests/unit/utils/test_authorized_sqlstore.py
Normal file
|
@ -0,0 +1,218 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
|
||||
from llama_stack.distribution.access_control.datatypes import Action
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
|
||||
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
|
||||
"""Test that fetch_all works correctly with where_sql for access control"""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_name = "test_access_control.db"
|
||||
base_sqlstore = SqlAlchemySqlStoreImpl(
|
||||
SqliteSqlStoreConfig(
|
||||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
||||
|
||||
# Create table with access control
|
||||
await sqlstore.create_table(
|
||||
table="documents",
|
||||
schema={
|
||||
"id": ColumnType.INTEGER,
|
||||
"title": ColumnType.STRING,
|
||||
"content": ColumnType.TEXT,
|
||||
},
|
||||
)
|
||||
|
||||
admin_user = User("admin-user", {"roles": ["admin"], "teams": ["engineering"]})
|
||||
regular_user = User("regular-user", {"roles": ["user"], "teams": ["marketing"]})
|
||||
|
||||
# Set user attributes for creating documents
|
||||
mock_get_authenticated_user.return_value = admin_user
|
||||
|
||||
# Insert documents with access attributes
|
||||
await sqlstore.insert("documents", {"id": 1, "title": "Admin Document", "content": "This is admin content"})
|
||||
|
||||
# Change user attributes
|
||||
mock_get_authenticated_user.return_value = regular_user
|
||||
|
||||
await sqlstore.insert("documents", {"id": 2, "title": "User Document", "content": "Public user content"})
|
||||
|
||||
# Test that access control works with where parameter
|
||||
mock_get_authenticated_user.return_value = admin_user
|
||||
|
||||
# Admin should see both documents
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["title"] == "Admin Document"
|
||||
|
||||
# User should only see their document
|
||||
mock_get_authenticated_user.return_value = regular_user
|
||||
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
||||
assert len(result.data) == 0
|
||||
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["title"] == "User Document"
|
||||
|
||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
|
||||
assert row is None
|
||||
|
||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
|
||||
assert row is not None
|
||||
assert row["title"] == "User Document"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||
"""Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_name = "test_consistency.db"
|
||||
base_sqlstore = SqlAlchemySqlStoreImpl(
|
||||
SqliteSqlStoreConfig(
|
||||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
||||
|
||||
await sqlstore.create_table(
|
||||
table="resources",
|
||||
schema={
|
||||
"id": ColumnType.STRING,
|
||||
"name": ColumnType.STRING,
|
||||
},
|
||||
)
|
||||
|
||||
# Test scenarios with different access control patterns
|
||||
test_scenarios = [
|
||||
# Scenario 1: Public record (no access control)
|
||||
{"id": "1", "name": "public", "access_attributes": None},
|
||||
# Scenario 2: Empty access control (should be treated as public)
|
||||
{"id": "2", "name": "empty", "access_attributes": {}},
|
||||
# Scenario 3: Record with roles requirement
|
||||
{"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
|
||||
# Scenario 4: Record with multiple attribute categories
|
||||
{"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
||||
# Scenario 5: Record with teams only (missing roles category)
|
||||
{"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
|
||||
# Scenario 6: Record with roles and projects
|
||||
{
|
||||
"id": "6",
|
||||
"name": "admin-project-x",
|
||||
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
|
||||
},
|
||||
]
|
||||
|
||||
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"]})
|
||||
for scenario in test_scenarios:
|
||||
await base_sqlstore.insert("resources", scenario)
|
||||
|
||||
# Test with different user configurations
|
||||
user_scenarios = [
|
||||
# User 1: No attributes (should only see public records)
|
||||
{"principal": "user1", "attributes": None},
|
||||
# User 2: Empty attributes (should only see public records)
|
||||
{"principal": "user2", "attributes": {}},
|
||||
# User 3: Admin role only
|
||||
{"principal": "user3", "attributes": {"roles": ["admin"]}},
|
||||
# User 4: ML team only
|
||||
{"principal": "user4", "attributes": {"teams": ["ml-team"]}},
|
||||
# User 5: Admin + ML team
|
||||
{"principal": "user5", "attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
||||
# User 6: Admin + Project X
|
||||
{"principal": "user6", "attributes": {"roles": ["admin"], "projects": ["project-x"]}},
|
||||
# User 7: Different role (should only see public)
|
||||
{"principal": "user7", "attributes": {"roles": ["viewer"]}},
|
||||
]
|
||||
|
||||
policy = default_policy()
|
||||
|
||||
for user_data in user_scenarios:
|
||||
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
|
||||
mock_get_authenticated_user.return_value = user
|
||||
|
||||
sql_results = await sqlstore.fetch_all("resources", policy=policy)
|
||||
sql_ids = {row["id"] for row in sql_results.data}
|
||||
policy_ids = set()
|
||||
for scenario in test_scenarios:
|
||||
sql_record = SqlRecord(
|
||||
record_id=scenario["id"], table_name="resources", access_attributes=scenario["access_attributes"]
|
||||
)
|
||||
|
||||
if is_action_allowed(policy, Action.READ, sql_record, user):
|
||||
policy_ids.add(scenario["id"])
|
||||
assert sql_ids == policy_ids, (
|
||||
f"Consistency failure for user {user.principal} with attributes {user.attributes}:\n"
|
||||
f"SQL returned: {sorted(sql_ids)}\n"
|
||||
f"Policy allows: {sorted(policy_ids)}\n"
|
||||
f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
|
||||
"""Test that user attributes are properly captured during insert"""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_name = "test_attributes.db"
|
||||
base_sqlstore = SqlAlchemySqlStoreImpl(
|
||||
SqliteSqlStoreConfig(
|
||||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||
|
||||
await authorized_store.create_table(
|
||||
table="user_data",
|
||||
schema={
|
||||
"id": ColumnType.STRING,
|
||||
"content": ColumnType.STRING,
|
||||
},
|
||||
)
|
||||
|
||||
mock_get_authenticated_user.return_value = User(
|
||||
"user-with-attrs", {"roles": ["editor"], "teams": ["content"], "projects": ["blog"]}
|
||||
)
|
||||
|
||||
await authorized_store.insert("user_data", {"id": "item1", "content": "User content"})
|
||||
|
||||
mock_get_authenticated_user.return_value = User("user-no-attrs", None)
|
||||
|
||||
await authorized_store.insert("user_data", {"id": "item2", "content": "Public content"})
|
||||
|
||||
mock_get_authenticated_user.return_value = None
|
||||
|
||||
await authorized_store.insert("user_data", {"id": "item3", "content": "Anonymous content"})
|
||||
result = await base_sqlstore.fetch_all("user_data", order_by=[("id", "asc")])
|
||||
assert len(result.data) == 3
|
||||
|
||||
# First item should have full attributes
|
||||
assert result.data[0]["id"] == "item1"
|
||||
assert result.data[0]["access_attributes"] == {"roles": ["editor"], "teams": ["content"], "projects": ["blog"]}
|
||||
|
||||
# Second item should have null attributes (user with no attributes)
|
||||
assert result.data[1]["id"] == "item2"
|
||||
assert result.data[1]["access_attributes"] is None
|
||||
|
||||
# Third item should have null attributes (no authenticated user)
|
||||
assert result.data[2]["id"] == "item3"
|
||||
assert result.data[2]["access_attributes"] is None
|
Loading…
Add table
Add a link
Reference in a new issue