mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
chore: simplify authorized sqlstore
# What does this PR do? ## Test Plan
This commit is contained in:
parent
d3600b92d1
commit
b4974d411d
7 changed files with 32 additions and 37 deletions
|
@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
|
||||||
storage_path.mkdir(parents=True, exist_ok=True)
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Initialize SQL store for metadata
|
# Initialize SQL store for metadata
|
||||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
|
||||||
await self.sql_store.create_table(
|
await self.sql_store.create_table(
|
||||||
"openai_files",
|
"openai_files",
|
||||||
{
|
{
|
||||||
|
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
|
||||||
if not self.sql_store:
|
if not self.sql_store:
|
||||||
raise RuntimeError("Files provider not initialized")
|
raise RuntimeError("Files provider not initialized")
|
||||||
|
|
||||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||||
if not row:
|
if not row:
|
||||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||||
|
|
||||||
|
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
|
||||||
|
|
||||||
paginated_result = await self.sql_store.fetch_all(
|
paginated_result = await self.sql_store.fetch_all(
|
||||||
table="openai_files",
|
table="openai_files",
|
||||||
policy=self.policy,
|
|
||||||
where=where_conditions if where_conditions else None,
|
where=where_conditions if where_conditions else None,
|
||||||
order_by=[("created_at", order.value)],
|
order_by=[("created_at", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
|
|
|
@ -137,7 +137,7 @@ class S3FilesImpl(Files):
|
||||||
where: dict[str, str | dict] = {"id": file_id}
|
where: dict[str, str | dict] = {"id": file_id}
|
||||||
if not return_expired:
|
if not return_expired:
|
||||||
where["expires_at"] = {">": self._now()}
|
where["expires_at"] = {">": self._now()}
|
||||||
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
|
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||||
return row
|
return row
|
||||||
|
|
||||||
|
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
|
||||||
self._client = _create_s3_client(self._config)
|
self._client = _create_s3_client(self._config)
|
||||||
await _create_bucket_if_not_exists(self._client, self._config)
|
await _create_bucket_if_not_exists(self._client, self._config)
|
||||||
|
|
||||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||||
await self._sql_store.create_table(
|
await self._sql_store.create_table(
|
||||||
"openai_files",
|
"openai_files",
|
||||||
{
|
{
|
||||||
|
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
|
||||||
|
|
||||||
paginated_result = await self.sql_store.fetch_all(
|
paginated_result = await self.sql_store.fetch_all(
|
||||||
table="openai_files",
|
table="openai_files",
|
||||||
policy=self.policy,
|
|
||||||
where=where_conditions,
|
where=where_conditions,
|
||||||
order_by=[("created_at", order.value)],
|
order_by=[("created_at", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
|
|
|
@ -54,7 +54,7 @@ class InferenceStore:
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Create the necessary tables if they don't exist."""
|
"""Create the necessary tables if they don't exist."""
|
||||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||||
await self.sql_store.create_table(
|
await self.sql_store.create_table(
|
||||||
"chat_completions",
|
"chat_completions",
|
||||||
{
|
{
|
||||||
|
@ -202,7 +202,6 @@ class InferenceStore:
|
||||||
order_by=[("created", order.value)],
|
order_by=[("created", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
policy=self.policy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [
|
data = [
|
||||||
|
@ -229,7 +228,6 @@ class InferenceStore:
|
||||||
row = await self.sql_store.fetch_one(
|
row = await self.sql_store.fetch_one(
|
||||||
table="chat_completions",
|
table="chat_completions",
|
||||||
where={"id": completion_id},
|
where={"id": completion_id},
|
||||||
policy=self.policy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
|
|
|
@ -28,8 +28,7 @@ class ResponsesStore:
|
||||||
sql_store_config = SqliteSqlStoreConfig(
|
sql_store_config = SqliteSqlStoreConfig(
|
||||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||||
)
|
)
|
||||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
|
||||||
self.policy = policy
|
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Create the necessary tables if they don't exist."""
|
"""Create the necessary tables if they don't exist."""
|
||||||
|
@ -87,7 +86,6 @@ class ResponsesStore:
|
||||||
order_by=[("created_at", order.value)],
|
order_by=[("created_at", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
policy=self.policy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
||||||
|
@ -105,7 +103,6 @@ class ResponsesStore:
|
||||||
row = await self.sql_store.fetch_one(
|
row = await self.sql_store.fetch_one(
|
||||||
"openai_responses",
|
"openai_responses",
|
||||||
where={"id": response_id},
|
where={"id": response_id},
|
||||||
policy=self.policy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
|
@ -116,7 +113,7 @@ class ResponsesStore:
|
||||||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||||
|
|
||||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
|
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||||
if not row:
|
if not row:
|
||||||
raise ValueError(f"Response with id {response_id} not found")
|
raise ValueError(f"Response with id {response_id} not found")
|
||||||
await self.sql_store.delete("openai_responses", where={"id": response_id})
|
await self.sql_store.delete("openai_responses", where={"id": response_id})
|
||||||
|
|
|
@ -53,13 +53,15 @@ class AuthorizedSqlStore:
|
||||||
access control policies, user attribute capture, and SQL filtering optimization.
|
access control policies, user attribute capture, and SQL filtering optimization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sql_store: SqlStore):
|
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
|
||||||
"""
|
"""
|
||||||
Initialize the authorization layer.
|
Initialize the authorization layer.
|
||||||
|
|
||||||
:param sql_store: Base SqlStore implementation to wrap
|
:param sql_store: Base SqlStore implementation to wrap
|
||||||
|
:param policy: Access control policy to use for authorization
|
||||||
"""
|
"""
|
||||||
self.sql_store = sql_store
|
self.sql_store = sql_store
|
||||||
|
self.policy = policy
|
||||||
self._detect_database_type()
|
self._detect_database_type()
|
||||||
self._validate_sql_optimized_policy()
|
self._validate_sql_optimized_policy()
|
||||||
|
|
||||||
|
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
|
||||||
async def fetch_all(
|
async def fetch_all(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
policy: list[AccessRule],
|
|
||||||
where: Mapping[str, Any] | None = None,
|
where: Mapping[str, Any] | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
cursor: tuple[str, str] | None = None,
|
cursor: tuple[str, str] | None = None,
|
||||||
) -> PaginatedResponse:
|
) -> PaginatedResponse:
|
||||||
"""Fetch all rows with automatic access control filtering."""
|
"""Fetch all rows with automatic access control filtering."""
|
||||||
access_where = self._build_access_control_where_clause(policy)
|
access_where = self._build_access_control_where_clause(self.policy)
|
||||||
rows = await self.sql_store.fetch_all(
|
rows = await self.sql_store.fetch_all(
|
||||||
table=table,
|
table=table,
|
||||||
where=where,
|
where=where,
|
||||||
|
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
|
||||||
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_action_allowed(policy, Action.READ, sql_record, current_user):
|
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
|
||||||
filtered_rows.append(row)
|
filtered_rows.append(row)
|
||||||
|
|
||||||
return PaginatedResponse(
|
return PaginatedResponse(
|
||||||
|
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
|
||||||
async def fetch_one(
|
async def fetch_one(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
policy: list[AccessRule],
|
|
||||||
where: Mapping[str, Any] | None = None,
|
where: Mapping[str, Any] | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Fetch one row with automatic access control checking."""
|
"""Fetch one row with automatic access control checking."""
|
||||||
results = await self.fetch_all(
|
results = await self.fetch_all(
|
||||||
table=table,
|
table=table,
|
||||||
policy=policy,
|
|
||||||
where=where,
|
where=where,
|
||||||
limit=1,
|
limit=1,
|
||||||
order_by=order_by,
|
order_by=order_by,
|
||||||
|
|
|
@ -57,7 +57,7 @@ def authorized_store(backend_config):
|
||||||
config = config_func()
|
config = config_func()
|
||||||
|
|
||||||
base_sqlstore = sqlstore_impl(config)
|
base_sqlstore = sqlstore_impl(config)
|
||||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||||
|
|
||||||
yield authorized_store
|
yield authorized_store
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
|
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
|
||||||
|
|
||||||
# Test fetching with no user - should not error on JSON comparison
|
# Test fetching with no user - should not error on JSON comparison
|
||||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
result = await authorized_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["id"] == "1"
|
assert result.data[0]["id"] == "1"
|
||||||
assert result.data[0]["access_attributes"] is None
|
assert result.data[0]["access_attributes"] is None
|
||||||
|
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
|
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
|
||||||
|
|
||||||
# Fetch all - admin should see both
|
# Fetch all - admin should see both
|
||||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
result = await authorized_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 2
|
assert len(result.data) == 2
|
||||||
|
|
||||||
# Test with non-admin user
|
# Test with non-admin user
|
||||||
|
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
mock_get_authenticated_user.return_value = regular_user
|
mock_get_authenticated_user.return_value = regular_user
|
||||||
|
|
||||||
# Should only see public record
|
# Should only see public record
|
||||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
result = await authorized_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["id"] == "1"
|
assert result.data[0]["id"] == "1"
|
||||||
|
|
||||||
|
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
|
|
||||||
# Now test with the multi-user who has both roles=admin and teams=dev
|
# Now test with the multi-user who has both roles=admin and teams=dev
|
||||||
mock_get_authenticated_user.return_value = multi_user
|
mock_get_authenticated_user.return_value = multi_user
|
||||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
result = await authorized_store.fetch_all(table_name)
|
||||||
|
|
||||||
# Should see:
|
# Should see:
|
||||||
# - public record (1) - no access_attributes
|
# - public record (1) - no access_attributes
|
||||||
|
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Create a new authorized store with the owner-only policy
|
||||||
|
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
|
||||||
|
|
||||||
# Test user1 access - should only see their own record
|
# Test user1 access - should only see their own record
|
||||||
mock_get_authenticated_user.return_value = user1
|
mock_get_authenticated_user.return_value = user1
|
||||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
result = await owner_only_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
|
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
|
||||||
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
|
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
|
||||||
|
|
||||||
# Test user2 access - should only see their own record
|
# Test user2 access - should only see their own record
|
||||||
mock_get_authenticated_user.return_value = user2
|
mock_get_authenticated_user.return_value = user2
|
||||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
result = await owner_only_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
|
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
|
||||||
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
|
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
|
||||||
|
|
||||||
# Test with anonymous user - should see no records
|
# Test with anonymous user - should see no records
|
||||||
mock_get_authenticated_user.return_value = None
|
mock_get_authenticated_user.return_value = None
|
||||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
result = await owner_only_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
|
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
||||||
db_path=tmp_dir + "/" + db_name,
|
db_path=tmp_dir + "/" + db_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||||
|
|
||||||
# Create table with access control
|
# Create table with access control
|
||||||
await sqlstore.create_table(
|
await sqlstore.create_table(
|
||||||
|
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
||||||
mock_get_authenticated_user.return_value = admin_user
|
mock_get_authenticated_user.return_value = admin_user
|
||||||
|
|
||||||
# Admin should see both documents
|
# Admin should see both documents
|
||||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
result = await sqlstore.fetch_all("documents", where={"id": 1})
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["title"] == "Admin Document"
|
assert result.data[0]["title"] == "Admin Document"
|
||||||
|
|
||||||
# User should only see their document
|
# User should only see their document
|
||||||
mock_get_authenticated_user.return_value = regular_user
|
mock_get_authenticated_user.return_value = regular_user
|
||||||
|
|
||||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
result = await sqlstore.fetch_all("documents", where={"id": 1})
|
||||||
assert len(result.data) == 0
|
assert len(result.data) == 0
|
||||||
|
|
||||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
|
result = await sqlstore.fetch_all("documents", where={"id": 2})
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["title"] == "User Document"
|
assert result.data[0]["title"] == "User Document"
|
||||||
|
|
||||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
|
row = await sqlstore.fetch_one("documents", where={"id": 1})
|
||||||
assert row is None
|
assert row is None
|
||||||
|
|
||||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
|
row = await sqlstore.fetch_one("documents", where={"id": 2})
|
||||||
assert row is not None
|
assert row is not None
|
||||||
assert row["title"] == "User Document"
|
assert row["title"] == "User Document"
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
db_path=tmp_dir + "/" + db_name,
|
db_path=tmp_dir + "/" + db_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||||
|
|
||||||
await sqlstore.create_table(
|
await sqlstore.create_table(
|
||||||
table="resources",
|
table="resources",
|
||||||
|
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
|
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
|
||||||
mock_get_authenticated_user.return_value = user
|
mock_get_authenticated_user.return_value = user
|
||||||
|
|
||||||
sql_results = await sqlstore.fetch_all("resources", policy=policy)
|
sql_results = await sqlstore.fetch_all("resources")
|
||||||
sql_ids = {row["id"] for row in sql_results.data}
|
sql_ids = {row["id"] for row in sql_results.data}
|
||||||
policy_ids = set()
|
policy_ids = set()
|
||||||
for scenario in test_scenarios:
|
for scenario in test_scenarios:
|
||||||
|
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
|
||||||
db_path=tmp_dir + "/" + db_name,
|
db_path=tmp_dir + "/" + db_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||||
|
|
||||||
await authorized_store.create_table(
|
await authorized_store.create_table(
|
||||||
table="user_data",
|
table="user_data",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue