chore: simplify authorized sqlstore

# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-09-19 14:59:30 -07:00
parent d3600b92d1
commit b4974d411d
7 changed files with 32 additions and 37 deletions

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control
await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document"
# User should only see their document
mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1
assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None
assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table(
table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy)
sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set()
for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name,
)
)
authorized_store = AuthorizedSqlStore(base_sqlstore)
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table(
table="user_data",