chore: simplify authorized sqlstore

# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-09-19 14:59:30 -07:00
parent d3600b92d1
commit b4974d411d
7 changed files with 32 additions and 37 deletions

View file

@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True) storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata # Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table( await self.sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store: if not self.sql_store:
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row: if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()") raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None, where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,

View file

@ -137,7 +137,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id} where: dict[str, str | dict] = {"id": file_id}
if not return_expired: if not return_expired:
where["expires_at"] = {">": self._now()} where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()") raise ResourceNotFoundError(file_id, "File", "files.list()")
return row return row
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config) self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config) await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table( await self._sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions, where=where_conditions,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,

View file

@ -54,7 +54,7 @@ class InferenceStore:
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
await self.sql_store.create_table( await self.sql_store.create_table(
"chat_completions", "chat_completions",
{ {
@ -202,7 +202,6 @@ class InferenceStore:
order_by=[("created", order.value)], order_by=[("created", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [ data = [
@ -229,7 +228,6 @@ class InferenceStore:
row = await self.sql_store.fetch_one( row = await self.sql_store.fetch_one(
table="chat_completions", table="chat_completions",
where={"id": completion_id}, where={"id": completion_id},
policy=self.policy,
) )
if not row: if not row:

View file

@ -28,8 +28,7 @@ class ResponsesStore:
sql_store_config = SqliteSqlStoreConfig( sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
) )
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
self.policy = policy
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
@ -87,7 +86,6 @@ class ResponsesStore:
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@ -105,7 +103,6 @@ class ResponsesStore:
row = await self.sql_store.fetch_one( row = await self.sql_store.fetch_one(
"openai_responses", "openai_responses",
where={"id": response_id}, where={"id": response_id},
policy=self.policy,
) )
if not row: if not row:
@ -116,7 +113,7 @@ class ResponsesStore:
return OpenAIResponseObjectWithInput(**row["response_object"]) return OpenAIResponseObjectWithInput(**row["response_object"])
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
if not row: if not row:
raise ValueError(f"Response with id {response_id} not found") raise ValueError(f"Response with id {response_id} not found")
await self.sql_store.delete("openai_responses", where={"id": response_id}) await self.sql_store.delete("openai_responses", where={"id": response_id})

View file

@ -53,13 +53,15 @@ class AuthorizedSqlStore:
access control policies, user attribute capture, and SQL filtering optimization. access control policies, user attribute capture, and SQL filtering optimization.
""" """
def __init__(self, sql_store: SqlStore): def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
""" """
Initialize the authorization layer. Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap :param sql_store: Base SqlStore implementation to wrap
:param policy: Access control policy to use for authorization
""" """
self.sql_store = sql_store self.sql_store = sql_store
self.policy = policy
self._detect_database_type() self._detect_database_type()
self._validate_sql_optimized_policy() self._validate_sql_optimized_policy()
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
async def fetch_all( async def fetch_all(
self, self,
table: str, table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
limit: int | None = None, limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None, cursor: tuple[str, str] | None = None,
) -> PaginatedResponse: ) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering.""" """Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(policy) access_where = self._build_access_control_where_clause(self.policy)
rows = await self.sql_store.fetch_all( rows = await self.sql_store.fetch_all(
table=table, table=table,
where=where, where=where,
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
) )
if is_action_allowed(policy, Action.READ, sql_record, current_user): if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
filtered_rows.append(row) filtered_rows.append(row)
return PaginatedResponse( return PaginatedResponse(
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
async def fetch_one( async def fetch_one(
self, self,
table: str, table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking.""" """Fetch one row with automatic access control checking."""
results = await self.fetch_all( results = await self.fetch_all(
table=table, table=table,
policy=policy,
where=where, where=where,
limit=1, limit=1,
order_by=order_by, order_by=order_by,

View file

@ -57,7 +57,7 @@ def authorized_store(backend_config):
config = config_func() config = config_func()
base_sqlstore = sqlstore_impl(config) base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
yield authorized_store yield authorized_store
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison # Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["id"] == "1" assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None assert result.data[0]["access_attributes"] is None
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both # Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 2 assert len(result.data) == 2
# Test with non-admin user # Test with non-admin user
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
mock_get_authenticated_user.return_value = regular_user mock_get_authenticated_user.return_value = regular_user
# Should only see public record # Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["id"] == "1" assert result.data[0]["id"] == "1"
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
# Now test with the multi-user who has both roles=admin and teams=dev # Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
# Should see: # Should see:
# - public record (1) - no access_attributes # - public record (1) - no access_attributes
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
), ),
] ]
# Create a new authorized store with the owner-only policy
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
# Test user1 access - should only see their own record # Test user1 access - should only see their own record
mock_get_authenticated_user.return_value = user1 mock_get_authenticated_user.return_value = user1
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Test user2 access - should only see their own record # Test user2 access - should only see their own record
mock_get_authenticated_user.return_value = user2 mock_get_authenticated_user.return_value = user2
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Test with anonymous user - should see no records # Test with anonymous user - should see no records
mock_get_authenticated_user.return_value = None mock_get_authenticated_user.return_value = None
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
finally: finally:

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
sqlstore = AuthorizedSqlStore(base_sqlstore) sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control # Create table with access control
await sqlstore.create_table( await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents # Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document" assert result.data[0]["title"] == "Admin Document"
# User should only see their document # User should only see their document
mock_get_authenticated_user.return_value = regular_user mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0 assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["title"] == "User Document" assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None assert row is not None
assert row["title"] == "User Document" assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
sqlstore = AuthorizedSqlStore(base_sqlstore) sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table( await sqlstore.create_table(
table="resources", table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"]) user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy) sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data} sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set() policy_ids = set()
for scenario in test_scenarios: for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table( await authorized_store.create_table(
table="user_data", table="user_data",