mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: simplify authorized sqlstore (#3496)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 35s
API Conformance Tests / check-schema-compatibility (push) Successful in 6s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Pre-commit / pre-commit (push) Successful in 1m19s
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Python Package Build Test / build (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 35s
API Conformance Tests / check-schema-compatibility (push) Successful in 6s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Pre-commit / pre-commit (push) Successful in 1m19s
# What does this PR do? This PR is generated with AI and reviewed by me. Refactors the AuthorizedSqlStore class to store the access policy as an instance variable rather than passing it as a parameter to each method call. This simplifies the API. # Test Plan existing tests
This commit is contained in:
parent
d3600b92d1
commit
f44eb935c4
7 changed files with 32 additions and 37 deletions
|
@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
|
||||||
storage_path.mkdir(parents=True, exist_ok=True)
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Initialize SQL store for metadata
|
# Initialize SQL store for metadata
|
||||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
|
||||||
await self.sql_store.create_table(
|
await self.sql_store.create_table(
|
||||||
"openai_files",
|
"openai_files",
|
||||||
{
|
{
|
||||||
|
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
|
||||||
if not self.sql_store:
|
if not self.sql_store:
|
||||||
raise RuntimeError("Files provider not initialized")
|
raise RuntimeError("Files provider not initialized")
|
||||||
|
|
||||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||||
if not row:
|
if not row:
|
||||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||||
|
|
||||||
|
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
|
||||||
|
|
||||||
paginated_result = await self.sql_store.fetch_all(
|
paginated_result = await self.sql_store.fetch_all(
|
||||||
table="openai_files",
|
table="openai_files",
|
||||||
policy=self.policy,
|
|
||||||
where=where_conditions if where_conditions else None,
|
where=where_conditions if where_conditions else None,
|
||||||
order_by=[("created_at", order.value)],
|
order_by=[("created_at", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
|
|
|
@ -137,7 +137,7 @@ class S3FilesImpl(Files):
|
||||||
where: dict[str, str | dict] = {"id": file_id}
|
where: dict[str, str | dict] = {"id": file_id}
|
||||||
if not return_expired:
|
if not return_expired:
|
||||||
where["expires_at"] = {">": self._now()}
|
where["expires_at"] = {">": self._now()}
|
||||||
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
|
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||||
return row
|
return row
|
||||||
|
|
||||||
|
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
|
||||||
self._client = _create_s3_client(self._config)
|
self._client = _create_s3_client(self._config)
|
||||||
await _create_bucket_if_not_exists(self._client, self._config)
|
await _create_bucket_if_not_exists(self._client, self._config)
|
||||||
|
|
||||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||||
await self._sql_store.create_table(
|
await self._sql_store.create_table(
|
||||||
"openai_files",
|
"openai_files",
|
||||||
{
|
{
|
||||||
|
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
|
||||||
|
|
||||||
paginated_result = await self.sql_store.fetch_all(
|
paginated_result = await self.sql_store.fetch_all(
|
||||||
table="openai_files",
|
table="openai_files",
|
||||||
policy=self.policy,
|
|
||||||
where=where_conditions,
|
where=where_conditions,
|
||||||
order_by=[("created_at", order.value)],
|
order_by=[("created_at", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
|
|
|
@ -54,7 +54,7 @@ class InferenceStore:
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Create the necessary tables if they don't exist."""
|
"""Create the necessary tables if they don't exist."""
|
||||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||||
await self.sql_store.create_table(
|
await self.sql_store.create_table(
|
||||||
"chat_completions",
|
"chat_completions",
|
||||||
{
|
{
|
||||||
|
@ -202,7 +202,6 @@ class InferenceStore:
|
||||||
order_by=[("created", order.value)],
|
order_by=[("created", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
policy=self.policy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [
|
data = [
|
||||||
|
@ -229,7 +228,6 @@ class InferenceStore:
|
||||||
row = await self.sql_store.fetch_one(
|
row = await self.sql_store.fetch_one(
|
||||||
table="chat_completions",
|
table="chat_completions",
|
||||||
where={"id": completion_id},
|
where={"id": completion_id},
|
||||||
policy=self.policy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
|
|
|
@ -28,8 +28,7 @@ class ResponsesStore:
|
||||||
sql_store_config = SqliteSqlStoreConfig(
|
sql_store_config = SqliteSqlStoreConfig(
|
||||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||||
)
|
)
|
||||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
|
||||||
self.policy = policy
|
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Create the necessary tables if they don't exist."""
|
"""Create the necessary tables if they don't exist."""
|
||||||
|
@ -87,7 +86,6 @@ class ResponsesStore:
|
||||||
order_by=[("created_at", order.value)],
|
order_by=[("created_at", order.value)],
|
||||||
cursor=("id", after) if after else None,
|
cursor=("id", after) if after else None,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
policy=self.policy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
||||||
|
@ -105,7 +103,6 @@ class ResponsesStore:
|
||||||
row = await self.sql_store.fetch_one(
|
row = await self.sql_store.fetch_one(
|
||||||
"openai_responses",
|
"openai_responses",
|
||||||
where={"id": response_id},
|
where={"id": response_id},
|
||||||
policy=self.policy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
|
@ -116,7 +113,7 @@ class ResponsesStore:
|
||||||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||||
|
|
||||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
|
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||||
if not row:
|
if not row:
|
||||||
raise ValueError(f"Response with id {response_id} not found")
|
raise ValueError(f"Response with id {response_id} not found")
|
||||||
await self.sql_store.delete("openai_responses", where={"id": response_id})
|
await self.sql_store.delete("openai_responses", where={"id": response_id})
|
||||||
|
|
|
@ -53,13 +53,15 @@ class AuthorizedSqlStore:
|
||||||
access control policies, user attribute capture, and SQL filtering optimization.
|
access control policies, user attribute capture, and SQL filtering optimization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sql_store: SqlStore):
|
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
|
||||||
"""
|
"""
|
||||||
Initialize the authorization layer.
|
Initialize the authorization layer.
|
||||||
|
|
||||||
:param sql_store: Base SqlStore implementation to wrap
|
:param sql_store: Base SqlStore implementation to wrap
|
||||||
|
:param policy: Access control policy to use for authorization
|
||||||
"""
|
"""
|
||||||
self.sql_store = sql_store
|
self.sql_store = sql_store
|
||||||
|
self.policy = policy
|
||||||
self._detect_database_type()
|
self._detect_database_type()
|
||||||
self._validate_sql_optimized_policy()
|
self._validate_sql_optimized_policy()
|
||||||
|
|
||||||
|
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
|
||||||
async def fetch_all(
|
async def fetch_all(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
policy: list[AccessRule],
|
|
||||||
where: Mapping[str, Any] | None = None,
|
where: Mapping[str, Any] | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
cursor: tuple[str, str] | None = None,
|
cursor: tuple[str, str] | None = None,
|
||||||
) -> PaginatedResponse:
|
) -> PaginatedResponse:
|
||||||
"""Fetch all rows with automatic access control filtering."""
|
"""Fetch all rows with automatic access control filtering."""
|
||||||
access_where = self._build_access_control_where_clause(policy)
|
access_where = self._build_access_control_where_clause(self.policy)
|
||||||
rows = await self.sql_store.fetch_all(
|
rows = await self.sql_store.fetch_all(
|
||||||
table=table,
|
table=table,
|
||||||
where=where,
|
where=where,
|
||||||
|
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
|
||||||
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_action_allowed(policy, Action.READ, sql_record, current_user):
|
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
|
||||||
filtered_rows.append(row)
|
filtered_rows.append(row)
|
||||||
|
|
||||||
return PaginatedResponse(
|
return PaginatedResponse(
|
||||||
|
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
|
||||||
async def fetch_one(
|
async def fetch_one(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
policy: list[AccessRule],
|
|
||||||
where: Mapping[str, Any] | None = None,
|
where: Mapping[str, Any] | None = None,
|
||||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Fetch one row with automatic access control checking."""
|
"""Fetch one row with automatic access control checking."""
|
||||||
results = await self.fetch_all(
|
results = await self.fetch_all(
|
||||||
table=table,
|
table=table,
|
||||||
policy=policy,
|
|
||||||
where=where,
|
where=where,
|
||||||
limit=1,
|
limit=1,
|
||||||
order_by=order_by,
|
order_by=order_by,
|
||||||
|
|
|
@ -57,7 +57,7 @@ def authorized_store(backend_config):
|
||||||
config = config_func()
|
config = config_func()
|
||||||
|
|
||||||
base_sqlstore = sqlstore_impl(config)
|
base_sqlstore = sqlstore_impl(config)
|
||||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||||
|
|
||||||
yield authorized_store
|
yield authorized_store
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
|
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
|
||||||
|
|
||||||
# Test fetching with no user - should not error on JSON comparison
|
# Test fetching with no user - should not error on JSON comparison
|
||||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
result = await authorized_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["id"] == "1"
|
assert result.data[0]["id"] == "1"
|
||||||
assert result.data[0]["access_attributes"] is None
|
assert result.data[0]["access_attributes"] is None
|
||||||
|
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
|
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
|
||||||
|
|
||||||
# Fetch all - admin should see both
|
# Fetch all - admin should see both
|
||||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
result = await authorized_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 2
|
assert len(result.data) == 2
|
||||||
|
|
||||||
# Test with non-admin user
|
# Test with non-admin user
|
||||||
|
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
mock_get_authenticated_user.return_value = regular_user
|
mock_get_authenticated_user.return_value = regular_user
|
||||||
|
|
||||||
# Should only see public record
|
# Should only see public record
|
||||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
result = await authorized_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["id"] == "1"
|
assert result.data[0]["id"] == "1"
|
||||||
|
|
||||||
|
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
|
||||||
|
|
||||||
# Now test with the multi-user who has both roles=admin and teams=dev
|
# Now test with the multi-user who has both roles=admin and teams=dev
|
||||||
mock_get_authenticated_user.return_value = multi_user
|
mock_get_authenticated_user.return_value = multi_user
|
||||||
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
result = await authorized_store.fetch_all(table_name)
|
||||||
|
|
||||||
# Should see:
|
# Should see:
|
||||||
# - public record (1) - no access_attributes
|
# - public record (1) - no access_attributes
|
||||||
|
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Create a new authorized store with the owner-only policy
|
||||||
|
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
|
||||||
|
|
||||||
# Test user1 access - should only see their own record
|
# Test user1 access - should only see their own record
|
||||||
mock_get_authenticated_user.return_value = user1
|
mock_get_authenticated_user.return_value = user1
|
||||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
result = await owner_only_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
|
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
|
||||||
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
|
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
|
||||||
|
|
||||||
# Test user2 access - should only see their own record
|
# Test user2 access - should only see their own record
|
||||||
mock_get_authenticated_user.return_value = user2
|
mock_get_authenticated_user.return_value = user2
|
||||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
result = await owner_only_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
|
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
|
||||||
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
|
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
|
||||||
|
|
||||||
# Test with anonymous user - should see no records
|
# Test with anonymous user - should see no records
|
||||||
mock_get_authenticated_user.return_value = None
|
mock_get_authenticated_user.return_value = None
|
||||||
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
|
result = await owner_only_store.fetch_all(table_name)
|
||||||
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
|
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
||||||
db_path=tmp_dir + "/" + db_name,
|
db_path=tmp_dir + "/" + db_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||||
|
|
||||||
# Create table with access control
|
# Create table with access control
|
||||||
await sqlstore.create_table(
|
await sqlstore.create_table(
|
||||||
|
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
||||||
mock_get_authenticated_user.return_value = admin_user
|
mock_get_authenticated_user.return_value = admin_user
|
||||||
|
|
||||||
# Admin should see both documents
|
# Admin should see both documents
|
||||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
result = await sqlstore.fetch_all("documents", where={"id": 1})
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["title"] == "Admin Document"
|
assert result.data[0]["title"] == "Admin Document"
|
||||||
|
|
||||||
# User should only see their document
|
# User should only see their document
|
||||||
mock_get_authenticated_user.return_value = regular_user
|
mock_get_authenticated_user.return_value = regular_user
|
||||||
|
|
||||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
result = await sqlstore.fetch_all("documents", where={"id": 1})
|
||||||
assert len(result.data) == 0
|
assert len(result.data) == 0
|
||||||
|
|
||||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
|
result = await sqlstore.fetch_all("documents", where={"id": 2})
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["title"] == "User Document"
|
assert result.data[0]["title"] == "User Document"
|
||||||
|
|
||||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
|
row = await sqlstore.fetch_one("documents", where={"id": 1})
|
||||||
assert row is None
|
assert row is None
|
||||||
|
|
||||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
|
row = await sqlstore.fetch_one("documents", where={"id": 2})
|
||||||
assert row is not None
|
assert row is not None
|
||||||
assert row["title"] == "User Document"
|
assert row["title"] == "User Document"
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
db_path=tmp_dir + "/" + db_name,
|
db_path=tmp_dir + "/" + db_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||||
|
|
||||||
await sqlstore.create_table(
|
await sqlstore.create_table(
|
||||||
table="resources",
|
table="resources",
|
||||||
|
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
|
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
|
||||||
mock_get_authenticated_user.return_value = user
|
mock_get_authenticated_user.return_value = user
|
||||||
|
|
||||||
sql_results = await sqlstore.fetch_all("resources", policy=policy)
|
sql_results = await sqlstore.fetch_all("resources")
|
||||||
sql_ids = {row["id"] for row in sql_results.data}
|
sql_ids = {row["id"] for row in sql_results.data}
|
||||||
policy_ids = set()
|
policy_ids = set()
|
||||||
for scenario in test_scenarios:
|
for scenario in test_scenarios:
|
||||||
|
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
|
||||||
db_path=tmp_dir + "/" + db_name,
|
db_path=tmp_dir + "/" + db_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||||
|
|
||||||
await authorized_store.create_table(
|
await authorized_store.create_table(
|
||||||
table="user_data",
|
table="user_data",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue