chore: simplify authorized sqlstore (#3496)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 35s
API Conformance Tests / check-schema-compatibility (push) Successful in 6s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Pre-commit / pre-commit (push) Successful in 1m19s

# What does this PR do?

This PR is generated with AI and reviewed by me.

Refactors the AuthorizedSqlStore class to store the access policy as an
instance variable rather than passing it as a parameter to each method
call. This simplifies the API.

# Test Plan

existing tests
This commit is contained in:
ehhuang 2025-09-19 16:13:56 -07:00 committed by GitHub
parent d3600b92d1
commit f44eb935c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 32 additions and 37 deletions

View file

@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True) storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata # Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table( await self.sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store: if not self.sql_store:
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row: if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()") raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None, where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,

View file

@ -137,7 +137,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id} where: dict[str, str | dict] = {"id": file_id}
if not return_expired: if not return_expired:
where["expires_at"] = {">": self._now()} where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()") raise ResourceNotFoundError(file_id, "File", "files.list()")
return row return row
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config) self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config) await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table( await self._sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions, where=where_conditions,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,

View file

@ -54,7 +54,7 @@ class InferenceStore:
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
await self.sql_store.create_table( await self.sql_store.create_table(
"chat_completions", "chat_completions",
{ {
@ -202,7 +202,6 @@ class InferenceStore:
order_by=[("created", order.value)], order_by=[("created", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [ data = [
@ -229,7 +228,6 @@ class InferenceStore:
row = await self.sql_store.fetch_one( row = await self.sql_store.fetch_one(
table="chat_completions", table="chat_completions",
where={"id": completion_id}, where={"id": completion_id},
policy=self.policy,
) )
if not row: if not row:

View file

@ -28,8 +28,7 @@ class ResponsesStore:
sql_store_config = SqliteSqlStoreConfig( sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
) )
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
self.policy = policy
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
@ -87,7 +86,6 @@ class ResponsesStore:
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@ -105,7 +103,6 @@ class ResponsesStore:
row = await self.sql_store.fetch_one( row = await self.sql_store.fetch_one(
"openai_responses", "openai_responses",
where={"id": response_id}, where={"id": response_id},
policy=self.policy,
) )
if not row: if not row:
@ -116,7 +113,7 @@ class ResponsesStore:
return OpenAIResponseObjectWithInput(**row["response_object"]) return OpenAIResponseObjectWithInput(**row["response_object"])
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
if not row: if not row:
raise ValueError(f"Response with id {response_id} not found") raise ValueError(f"Response with id {response_id} not found")
await self.sql_store.delete("openai_responses", where={"id": response_id}) await self.sql_store.delete("openai_responses", where={"id": response_id})

View file

@ -53,13 +53,15 @@ class AuthorizedSqlStore:
access control policies, user attribute capture, and SQL filtering optimization. access control policies, user attribute capture, and SQL filtering optimization.
""" """
def __init__(self, sql_store: SqlStore): def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
""" """
Initialize the authorization layer. Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap :param sql_store: Base SqlStore implementation to wrap
:param policy: Access control policy to use for authorization
""" """
self.sql_store = sql_store self.sql_store = sql_store
self.policy = policy
self._detect_database_type() self._detect_database_type()
self._validate_sql_optimized_policy() self._validate_sql_optimized_policy()
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
async def fetch_all( async def fetch_all(
self, self,
table: str, table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
limit: int | None = None, limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None, cursor: tuple[str, str] | None = None,
) -> PaginatedResponse: ) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering.""" """Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(policy) access_where = self._build_access_control_where_clause(self.policy)
rows = await self.sql_store.fetch_all( rows = await self.sql_store.fetch_all(
table=table, table=table,
where=where, where=where,
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
) )
if is_action_allowed(policy, Action.READ, sql_record, current_user): if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
filtered_rows.append(row) filtered_rows.append(row)
return PaginatedResponse( return PaginatedResponse(
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
async def fetch_one( async def fetch_one(
self, self,
table: str, table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking.""" """Fetch one row with automatic access control checking."""
results = await self.fetch_all( results = await self.fetch_all(
table=table, table=table,
policy=policy,
where=where, where=where,
limit=1, limit=1,
order_by=order_by, order_by=order_by,

View file

@ -57,7 +57,7 @@ def authorized_store(backend_config):
config = config_func() config = config_func()
base_sqlstore = sqlstore_impl(config) base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
yield authorized_store yield authorized_store
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison # Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["id"] == "1" assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None assert result.data[0]["access_attributes"] is None
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both # Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 2 assert len(result.data) == 2
# Test with non-admin user # Test with non-admin user
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
mock_get_authenticated_user.return_value = regular_user mock_get_authenticated_user.return_value = regular_user
# Should only see public record # Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["id"] == "1" assert result.data[0]["id"] == "1"
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
# Now test with the multi-user who has both roles=admin and teams=dev # Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
# Should see: # Should see:
# - public record (1) - no access_attributes # - public record (1) - no access_attributes
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
), ),
] ]
# Create a new authorized store with the owner-only policy
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
# Test user1 access - should only see their own record # Test user1 access - should only see their own record
mock_get_authenticated_user.return_value = user1 mock_get_authenticated_user.return_value = user1
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Test user2 access - should only see their own record # Test user2 access - should only see their own record
mock_get_authenticated_user.return_value = user2 mock_get_authenticated_user.return_value = user2
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Test with anonymous user - should see no records # Test with anonymous user - should see no records
mock_get_authenticated_user.return_value = None mock_get_authenticated_user.return_value = None
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
finally: finally:

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
sqlstore = AuthorizedSqlStore(base_sqlstore) sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control # Create table with access control
await sqlstore.create_table( await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents # Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document" assert result.data[0]["title"] == "Admin Document"
# User should only see their document # User should only see their document
mock_get_authenticated_user.return_value = regular_user mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0 assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["title"] == "User Document" assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None assert row is not None
assert row["title"] == "User Document" assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
sqlstore = AuthorizedSqlStore(base_sqlstore) sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table( await sqlstore.create_table(
table="resources", table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"]) user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy) sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data} sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set() policy_ids = set()
for scenario in test_scenarios: for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table( await authorized_store.create_table(
table="user_data", table="user_data",