chore: simplify authorized sqlstore (#3496)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 35s
API Conformance Tests / check-schema-compatibility (push) Successful in 6s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Pre-commit / pre-commit (push) Successful in 1m19s

# What does this PR do?

This PR is generated with AI and reviewed by me.

Refactors the AuthorizedSqlStore class to store the access policy as an
instance variable rather than passing it as a parameter to each method
call. This simplifies the API.

# Test Plan

existing tests
This commit is contained in:
ehhuang 2025-09-19 16:13:56 -07:00 committed by GitHub
parent d3600b92d1
commit f44eb935c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 32 additions and 37 deletions

View file

@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table(
"openai_files",
{
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -137,7 +137,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id}
if not return_expired:
where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()")
return row
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table(
"openai_files",
{
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -54,7 +54,7 @@ class InferenceStore:
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
await self.sql_store.create_table(
"chat_completions",
{
@ -202,7 +202,6 @@ class InferenceStore:
order_by=[("created", order.value)],
cursor=("id", after) if after else None,
limit=limit,
policy=self.policy,
)
data = [
@ -229,7 +228,6 @@ class InferenceStore:
row = await self.sql_store.fetch_one(
table="chat_completions",
where={"id": completion_id},
policy=self.policy,
)
if not row:

View file

@ -28,8 +28,7 @@ class ResponsesStore:
sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
)
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
self.policy = policy
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
@ -87,7 +86,6 @@ class ResponsesStore:
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
policy=self.policy,
)
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@ -105,7 +103,6 @@ class ResponsesStore:
row = await self.sql_store.fetch_one(
"openai_responses",
where={"id": response_id},
policy=self.policy,
)
if not row:
@ -116,7 +113,7 @@ class ResponsesStore:
return OpenAIResponseObjectWithInput(**row["response_object"])
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
if not row:
raise ValueError(f"Response with id {response_id} not found")
await self.sql_store.delete("openai_responses", where={"id": response_id})

View file

@ -53,13 +53,15 @@ class AuthorizedSqlStore:
access control policies, user attribute capture, and SQL filtering optimization.
"""
def __init__(self, sql_store: SqlStore):
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
"""
Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap
:param policy: Access control policy to use for authorization
"""
self.sql_store = sql_store
self.policy = policy
self._detect_database_type()
self._validate_sql_optimized_policy()
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
async def fetch_all(
self,
table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(policy)
access_where = self._build_access_control_where_clause(self.policy)
rows = await self.sql_store.fetch_all(
table=table,
where=where,
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(policy, Action.READ, sql_record, current_user):
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
return PaginatedResponse(
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
async def fetch_one(
self,
table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking."""
results = await self.fetch_all(
table=table,
policy=policy,
where=where,
limit=1,
order_by=order_by,

View file

@ -57,7 +57,7 @@ def authorized_store(backend_config):
config = config_func()
base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore)
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
yield authorized_store
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 2
# Test with non-admin user
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
mock_get_authenticated_user.return_value = regular_user
# Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
# Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
# Should see:
# - public record (1) - no access_attributes
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
),
]
# Create a new authorized store with the owner-only policy
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
# Test user1 access - should only see their own record
mock_get_authenticated_user.return_value = user1
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Test user2 access - should only see their own record
mock_get_authenticated_user.return_value = user2
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Test with anonymous user - should see no records
mock_get_authenticated_user.return_value = None
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
finally:

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control
await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document"
# User should only see their document
mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1
assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None
assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table(
table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy)
sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set()
for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name,
)
)
authorized_store = AuthorizedSqlStore(base_sqlstore)
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table(
table="user_data",