fix: authorized sql store with postgres (#2641)

# What does this PR do?
PostgreSQL uses different JSON extraction syntax from SQLite: `->`/`->>` operators instead of `JSON_EXTRACT(...)`, and it represents a stored JSON `null` differently. The `AuthorizedSqlStore` WHERE clauses previously hardcoded the SQLite forms, so access-controlled queries broke against a Postgres-backed store. This PR detects the backend and generates dialect-appropriate SQL.
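A minimal sketch of the dialect difference (illustrative strings only, mirroring the helpers added in `authorized_sqlstore.py` below):

```python
column, path = "access_attributes", "roles"

# PostgreSQL: -> keeps the JSON type, ->> extracts as text
pg_json = f"{column}->'{path}'"    # access_attributes->'roles'
pg_text = f"{column}->>'{path}'"   # access_attributes->>'roles'

# SQLite: JSON_EXTRACT serves both purposes
sqlite_expr = f"JSON_EXTRACT({column}, '$.{path}')"  # JSON_EXTRACT(access_attributes, '$.roles')
```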

## Test Plan
Added an integration test that exercises the JSON comparisons against both SQLite and PostgreSQL (the Postgres case is gated by `ENABLE_POSTGRES_TESTS`), plus a CI workflow that runs it against a Postgres 15 service container.
ehhuang committed 2025-07-07 19:36:34 -07:00 (committed by GitHub)
parent 5bb3817c49, commit e9926564bd
7 changed files with 337 additions and 27 deletions


@@ -0,0 +1,70 @@
name: SqlStore Integration Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
    paths:
      - 'llama_stack/providers/utils/sqlstore/**'
      - 'tests/integration/sqlstore/**'
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - '.github/workflows/integration-sql-store-tests.yml' # This workflow

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test-postgres:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12", "3.13"]
      fail-fast: false

    services:
      postgres:
        image: postgres:15
        env:
          POSTGRES_USER: llamastack
          POSTGRES_PASSWORD: llamastack
          POSTGRES_DB: llamastack
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5

    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Install dependencies
        uses: ./.github/actions/setup-runner
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run SqlStore Integration Tests
        env:
          ENABLE_POSTGRES_TESTS: "true"
          POSTGRES_HOST: localhost
          POSTGRES_PORT: 5432
          POSTGRES_DB: llamastack
          POSTGRES_USER: llamastack
          POSTGRES_PASSWORD: llamastack
        run: |
          uv run pytest -sv tests/integration/providers/utils/sqlstore/

      - name: Upload test logs
        if: ${{ always() }}
        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
        with:
          name: postgres-test-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }}
          path: |
            *.log
          retention-days: 1


@@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
 from llama_stack.log import get_logger

 from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
+from .sqlstore import SqlStoreType

 logger = get_logger(name=__name__, category="authorized_sqlstore")
@@ -71,9 +72,18 @@ class AuthorizedSqlStore:
         :param sql_store: Base SqlStore implementation to wrap
         """
         self.sql_store = sql_store
+        self._detect_database_type()
         self._validate_sql_optimized_policy()

+    def _detect_database_type(self) -> None:
+        """Detect the database type from the underlying SQL store."""
+        if not hasattr(self.sql_store, "config"):
+            raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
+
+        self.database_type = self.sql_store.config.type
+        if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
     def _validate_sql_optimized_policy(self) -> None:
         """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
@@ -181,6 +191,50 @@ class AuthorizedSqlStore:
         else:
             return self._build_conservative_where_clause()

+    def _json_extract(self, column: str, path: str) -> str:
+        """Extract JSON value (keeping JSON type).
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value
+        """
+        if self.database_type == SqlStoreType.postgres:
+            return f"{column}->'{path}'"
+        elif self.database_type == SqlStoreType.sqlite:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _json_extract_text(self, column: str, path: str) -> str:
+        """Extract JSON value as text.
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value as text
+        """
+        if self.database_type == SqlStoreType.postgres:
+            return f"{column}->>'{path}'"
+        elif self.database_type == SqlStoreType.sqlite:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _get_public_access_conditions(self) -> list[str]:
+        """Get the SQL conditions for public access."""
+        if self.database_type == SqlStoreType.postgres:
+            # Postgres stores JSON null as 'null'
+            return ["access_attributes::text = 'null'"]
+        elif self.database_type == SqlStoreType.sqlite:
+            return ["access_attributes = 'null'"]
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
     def _build_default_policy_where_clause(self) -> str:
         """Build SQL WHERE clause for the default policy.
@@ -189,29 +243,32 @@ class AuthorizedSqlStore:
         """
         current_user = get_authenticated_user()

+        base_conditions = self._get_public_access_conditions()
         if not current_user or not current_user.attributes:
-            return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
+            # Only allow public records
+            return f"({' OR '.join(base_conditions)})"
         else:
-            base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
             user_attr_conditions = []

             for attr_key, user_values in current_user.attributes.items():
                 if user_values:
                     value_conditions = []
                     for value in user_values:
-                        value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
+                        # Check if JSON array contains the value
+                        escaped_value = value.replace("'", "''")
+                        json_text = self._json_extract_text("access_attributes", attr_key)
+                        value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")

                     if value_conditions:
-                        category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
+                        # Check if the category is missing (NULL)
+                        category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
                         user_matches_category = f"({' OR '.join(value_conditions)})"
                         user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")

             if user_attr_conditions:
                 all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
                 base_conditions.append(all_requirements_met)
                 return f"({' OR '.join(base_conditions)})"
             else:
                 return f"({' OR '.join(base_conditions)})"
     def _build_conservative_where_clause(self) -> str:
@@ -222,5 +279,8 @@ class AuthorizedSqlStore:
         current_user = get_authenticated_user()

         if not current_user:
-            return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
+            # Only allow public records
+            base_conditions = self._get_public_access_conditions()
+            return f"({' OR '.join(base_conditions)})"

         return "1=1"


@@ -4,9 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 from abc import abstractmethod
-from enum import Enum
+from enum import StrEnum
 from pathlib import Path
 from typing import Annotated, Literal
@@ -19,7 +18,7 @@ from .api import SqlStore
 sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]


-class SqlStoreType(Enum):
+class SqlStoreType(StrEnum):
     sqlite = "sqlite"
     postgres = "postgres"
@@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):

 class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
-    type: Literal["sqlite"] = SqlStoreType.sqlite.value
+    type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
     db_path: str = Field(
         default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
         description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
@@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):

 class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
-    type: Literal["postgres"] = SqlStoreType.postgres.value
+    type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
     host: str = "localhost"
     port: int = 5432
     db: str = "llamastack"
@@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:

 def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
-    if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
+    if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
         from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

         impl = SqlAlchemySqlStoreImpl(config)
View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import tempfile
from unittest.mock import patch

import pytest

from llama_stack.distribution.access_control.access_control import default_policy
from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig


def get_postgres_config():
    """Get PostgreSQL configuration if tests are enabled."""
    return PostgresSqlStoreConfig(
        host=os.environ.get("POSTGRES_HOST", "localhost"),
        port=int(os.environ.get("POSTGRES_PORT", "5432")),
        db=os.environ.get("POSTGRES_DB", "llamastack"),
        user=os.environ.get("POSTGRES_USER", "llamastack"),
        password=os.environ.get("POSTGRES_PASSWORD", "llamastack"),
    )


def get_sqlite_config():
    """Get SQLite configuration with temporary database."""
    tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    tmp_file.close()
    return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "backend_config",
    [
        pytest.param(
            ("postgres", get_postgres_config),
            marks=pytest.mark.skipif(
                not os.environ.get("ENABLE_POSTGRES_TESTS"),
                reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
            ),
            id="postgres",
        ),
        pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
    ],
)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_json_comparison(mock_get_authenticated_user, backend_config):
    """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
    backend_name, config_func = backend_config

    # Handle different config types
    if backend_name == "postgres":
        config = config_func()
        cleanup_path = None
    else:  # sqlite
        config, cleanup_path = config_func()

    try:
        base_sqlstore = SqlAlchemySqlStoreImpl(config)
        authorized_store = AuthorizedSqlStore(base_sqlstore)

        # Create test table
        table_name = f"test_json_comparison_{backend_name}"
        await authorized_store.create_table(
            table=table_name,
            schema={
                "id": ColumnType.STRING,
                "data": ColumnType.STRING,
            },
        )

        try:
            # Test with no authenticated user (should handle JSON null comparison)
            mock_get_authenticated_user.return_value = None

            # Insert some test data
            await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})

            # Test fetching with no user - should not error on JSON comparison
            result = await authorized_store.fetch_all(table_name, policy=default_policy())
            assert len(result.data) == 1
            assert result.data[0]["id"] == "1"
            assert result.data[0]["access_attributes"] is None

            # Test with authenticated user
            test_user = User("test-user", {"roles": ["admin"]})
            mock_get_authenticated_user.return_value = test_user

            # Insert data with user attributes
            await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})

            # Fetch all - admin should see both
            result = await authorized_store.fetch_all(table_name, policy=default_policy())
            assert len(result.data) == 2

            # Test with non-admin user
            regular_user = User("regular-user", {"roles": ["user"]})
            mock_get_authenticated_user.return_value = regular_user

            # Should only see public record
            result = await authorized_store.fetch_all(table_name, policy=default_policy())
            assert len(result.data) == 1
            assert result.data[0]["id"] == "1"

            # Test the category missing branch: user with multiple attributes
            multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
            mock_get_authenticated_user.return_value = multi_user

            # Insert record with multi-user (has both roles and teams)
            await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})

            # Test different user types to create records with different attribute patterns
            # Record with only roles (teams category will be missing)
            roles_only_user = User("roles-user", {"roles": ["admin"]})
            mock_get_authenticated_user.return_value = roles_only_user
            await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})

            # Record with only teams (roles category will be missing)
            teams_only_user = User("teams-user", {"teams": ["dev"]})
            mock_get_authenticated_user.return_value = teams_only_user
            await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})

            # Record with different roles/teams (shouldn't match our test user)
            different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
            mock_get_authenticated_user.return_value = different_user
            await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})

            # Now test with the multi-user who has both roles=admin and teams=dev
            mock_get_authenticated_user.return_value = multi_user
            result = await authorized_store.fetch_all(table_name, policy=default_policy())

            # Should see:
            # - public record (1) - no access_attributes
            # - admin record (2) - user matches roles=admin, teams missing (allowed)
            # - multi_user record (3) - user matches both roles=admin and teams=dev
            # - roles_only record (4) - user matches roles=admin, teams missing (allowed)
            # - teams_only record (5) - user matches teams=dev, roles missing (allowed)
            # Should NOT see:
            # - different_user record (6) - user doesn't match roles=user or teams=qa
            expected_ids = {"1", "2", "3", "4", "5"}
            actual_ids = {record["id"] for record in result.data}
            assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"

            # Verify the category missing logic specifically
            # Records 4 and 5 test the "category missing" branch where one attribute category is missing
            category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
            assert category_test_ids == {"4", "5"}, (
                f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
            )

        finally:
            # Clean up records
            for record_id in ["1", "2", "3", "4", "5", "6"]:
                try:
                    await base_sqlstore.delete(table_name, {"id": record_id})
                except Exception:
                    pass

    finally:
        # Clean up temporary SQLite database file if needed
        if cleanup_path:
            try:
                os.unlink(cleanup_path)
            except OSError:
                pass
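To reproduce the CI run locally, something like the following should work (a sketch: it assumes you run from the repo root, and the Postgres parametrization additionally assumes a server matching the defaults in `get_postgres_config`):

```python
import os
import sys

import pytest

# Drop this line to run only the SQLite parametrization.
os.environ.setdefault("ENABLE_POSTGRES_TESTS", "true")
sys.exit(pytest.main(["-sv", "tests/integration/providers/utils/sqlstore/"]))
```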


@@ -104,19 +104,17 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
     # Test scenarios with different access control patterns
     test_scenarios = [
-        # Scenario 1: Public record (no access control)
+        # Scenario 1: Public record (no access control - represents None user insert)
         {"id": "1", "name": "public", "access_attributes": None},
-        # Scenario 2: Empty access control (should be treated as public)
-        {"id": "2", "name": "empty", "access_attributes": {}},
-        # Scenario 3: Record with roles requirement
-        {"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
-        # Scenario 4: Record with multiple attribute categories
-        {"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
-        # Scenario 5: Record with teams only (missing roles category)
-        {"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
-        # Scenario 6: Record with roles and projects
+        # Scenario 2: Record with roles requirement
+        {"id": "2", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
+        # Scenario 3: Record with multiple attribute categories
+        {"id": "3", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
+        # Scenario 4: Record with teams only (missing roles category)
+        {"id": "4", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
+        # Scenario 5: Record with roles and projects
         {
-            "id": "6",
+            "id": "5",
             "name": "admin-project-x",
             "access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
         },