fix: authorized sql store with postgres (#2641)
Some checks failed

# What does this PR do?
PostgreSQL uses different JSON extraction syntax (the `->` and `->>` operators) than SQLite (`JSON_EXTRACT`), so the access-control WHERE clauses generated by `AuthorizedSqlStore` failed against a Postgres backend. This change detects the database type and generates the appropriate syntax for each.
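For illustration, a minimal sketch of the syntax difference. The SQLite half is runnable with the stdlib `sqlite3` module (assuming the bundled SQLite includes the JSON1 functions, true for modern builds); the Postgres operators are shown in comments for comparison:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (access_attributes TEXT)")
conn.execute("""INSERT INTO t VALUES ('{"roles": ["admin"]}')""")

# SQLite: JSON_EXTRACT(column, '$.path')
row = conn.execute("SELECT JSON_EXTRACT(access_attributes, '$.roles') FROM t").fetchone()
print(row[0])  # '["admin"]' (arrays/objects come back as JSON text)

# PostgreSQL has no JSON_EXTRACT; it uses operators instead:
#   access_attributes->'roles'    -- extracts as JSON
#   access_attributes->>'roles'   -- extracts as text
```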

## Test Plan
Added an integration test exercising JSON comparisons against both SQLite and PostgreSQL, plus a CI workflow (SqlStore Integration Tests) that runs it against a Postgres service container.
commit e9926564bd (parent 5bb3817c49)
ehhuang, 2025-07-07 19:36:34 -07:00, committed by GitHub
7 changed files with 337 additions and 27 deletions

.github/workflows/integration-sql-store-tests.yml (new file)

@@ -0,0 +1,70 @@
name: SqlStore Integration Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
    paths:
      - 'llama_stack/providers/utils/sqlstore/**'
      - 'tests/integration/sqlstore/**'
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - '.github/workflows/integration-sql-store-tests.yml' # This workflow

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  test-postgres:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12", "3.13"]
      fail-fast: false

    services:
      postgres:
        image: postgres:15
        env:
          POSTGRES_USER: llamastack
          POSTGRES_PASSWORD: llamastack
          POSTGRES_DB: llamastack
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5

    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Install dependencies
        uses: ./.github/actions/setup-runner
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run SqlStore Integration Tests
        env:
          ENABLE_POSTGRES_TESTS: "true"
          POSTGRES_HOST: localhost
          POSTGRES_PORT: 5432
          POSTGRES_DB: llamastack
          POSTGRES_USER: llamastack
          POSTGRES_PASSWORD: llamastack
        run: |
          uv run pytest -sv tests/integration/providers/utils/sqlstore/

      - name: Upload test logs
        if: ${{ always() }}
        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
        with:
          name: postgres-test-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }}
          path: |
            *.log
          retention-days: 1

llama_stack/providers/utils/sqlstore/authorized_sqlstore.py

@@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
 from llama_stack.log import get_logger

 from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
+from .sqlstore import SqlStoreType

 logger = get_logger(name=__name__, category="authorized_sqlstore")

@@ -71,9 +72,18 @@ class AuthorizedSqlStore:
         :param sql_store: Base SqlStore implementation to wrap
         """
         self.sql_store = sql_store
+        self._detect_database_type()
         self._validate_sql_optimized_policy()

+    def _detect_database_type(self) -> None:
+        """Detect the database type from the underlying SQL store."""
+        if not hasattr(self.sql_store, "config"):
+            raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
+
+        self.database_type = self.sql_store.config.type
+        if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
     def _validate_sql_optimized_policy(self) -> None:
         """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().

@@ -181,6 +191,50 @@ class AuthorizedSqlStore:
         else:
             return self._build_conservative_where_clause()

+    def _json_extract(self, column: str, path: str) -> str:
+        """Extract JSON value (keeping JSON type).
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value
+        """
+        if self.database_type == SqlStoreType.postgres:
+            return f"{column}->'{path}'"
+        elif self.database_type == SqlStoreType.sqlite:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _json_extract_text(self, column: str, path: str) -> str:
+        """Extract JSON value as text.
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value as text
+        """
+        if self.database_type == SqlStoreType.postgres:
+            return f"{column}->>'{path}'"
+        elif self.database_type == SqlStoreType.sqlite:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _get_public_access_conditions(self) -> list[str]:
+        """Get the SQL conditions for public access."""
+        if self.database_type == SqlStoreType.postgres:
+            # Postgres stores JSON null as 'null'
+            return ["access_attributes::text = 'null'"]
+        elif self.database_type == SqlStoreType.sqlite:
+            return ["access_attributes = 'null'"]
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
     def _build_default_policy_where_clause(self) -> str:
         """Build SQL WHERE clause for the default policy.

@@ -189,29 +243,32 @@ class AuthorizedSqlStore:
         """
         current_user = get_authenticated_user()

+        base_conditions = self._get_public_access_conditions()
         if not current_user or not current_user.attributes:
-            return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
+            # Only allow public records
+            return f"({' OR '.join(base_conditions)})"
         else:
-            base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
             user_attr_conditions = []

             for attr_key, user_values in current_user.attributes.items():
                 if user_values:
                     value_conditions = []
                     for value in user_values:
-                        value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
+                        # Check if JSON array contains the value
+                        escaped_value = value.replace("'", "''")
+                        json_text = self._json_extract_text("access_attributes", attr_key)
+                        value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")

                     if value_conditions:
-                        category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
+                        # Check if the category is missing (NULL)
+                        category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
                         user_matches_category = f"({' OR '.join(value_conditions)})"
                         user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")

             if user_attr_conditions:
                 all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
                 base_conditions.append(all_requirements_met)
-                return f"({' OR '.join(base_conditions)})"
-            else:
-                return f"({' OR '.join(base_conditions)})"
+
+            return f"({' OR '.join(base_conditions)})"

     def _build_conservative_where_clause(self) -> str:

@@ -222,5 +279,8 @@ class AuthorizedSqlStore:
         current_user = get_authenticated_user()

         if not current_user:
-            return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
+            # Only allow public records
+            base_conditions = self._get_public_access_conditions()
+            return f"({' OR '.join(base_conditions)})"

         return "1=1"
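Outside the class, the dialect branching can be sanity-checked as plain functions. This is a sketch mirroring the `_json_extract` / `_json_extract_text` helpers, with the database type passed in explicitly as a string for illustration:

```python
def json_extract(database_type: str, column: str, path: str) -> str:
    # Mirrors AuthorizedSqlStore._json_extract: keep the JSON type
    if database_type == "postgres":
        return f"{column}->'{path}'"
    elif database_type == "sqlite":
        return f"JSON_EXTRACT({column}, '$.{path}')"
    raise ValueError(f"Unsupported database type: {database_type}")


def json_extract_text(database_type: str, column: str, path: str) -> str:
    # Mirrors AuthorizedSqlStore._json_extract_text: extract as text
    if database_type == "postgres":
        return f"{column}->>'{path}'"
    elif database_type == "sqlite":
        return f"JSON_EXTRACT({column}, '$.{path}')"
    raise ValueError(f"Unsupported database type: {database_type}")


print(json_extract("postgres", "access_attributes", "roles"))
# access_attributes->'roles'
print(json_extract_text("sqlite", "access_attributes", "roles"))
# JSON_EXTRACT(access_attributes, '$.roles')
```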

llama_stack/providers/utils/sqlstore/sqlstore.py

@@ -4,9 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 from abc import abstractmethod
-from enum import Enum
+from enum import StrEnum
 from pathlib import Path
 from typing import Annotated, Literal

@@ -19,7 +18,7 @@ from .api import SqlStore
 sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]

-class SqlStoreType(Enum):
+class SqlStoreType(StrEnum):
     sqlite = "sqlite"
     postgres = "postgres"

@@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):
 class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
-    type: Literal["sqlite"] = SqlStoreType.sqlite.value
+    type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
     db_path: str = Field(
         default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
         description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",

@@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
 class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
-    type: Literal["postgres"] = SqlStoreType.postgres.value
+    type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
     host: str = "localhost"
     port: int = 5432
     db: str = "llamastack"

@@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
 def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
-    if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
+    if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
         from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

         impl = SqlAlchemySqlStoreImpl(config)


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import tempfile
from unittest.mock import patch

import pytest

from llama_stack.distribution.access_control.access_control import default_policy
from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig


def get_postgres_config():
    """Get PostgreSQL configuration if tests are enabled."""
    return PostgresSqlStoreConfig(
        host=os.environ.get("POSTGRES_HOST", "localhost"),
        port=int(os.environ.get("POSTGRES_PORT", "5432")),
        db=os.environ.get("POSTGRES_DB", "llamastack"),
        user=os.environ.get("POSTGRES_USER", "llamastack"),
        password=os.environ.get("POSTGRES_PASSWORD", "llamastack"),
    )


def get_sqlite_config():
    """Get SQLite configuration with temporary database."""
    tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    tmp_file.close()
    return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "backend_config",
    [
        pytest.param(
            ("postgres", get_postgres_config),
            marks=pytest.mark.skipif(
                not os.environ.get("ENABLE_POSTGRES_TESTS"),
                reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
            ),
            id="postgres",
        ),
        pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
    ],
)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_json_comparison(mock_get_authenticated_user, backend_config):
    """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
    backend_name, config_func = backend_config

    # Handle different config types
    if backend_name == "postgres":
        config = config_func()
        cleanup_path = None
    else:  # sqlite
        config, cleanup_path = config_func()

    try:
        base_sqlstore = SqlAlchemySqlStoreImpl(config)
        authorized_store = AuthorizedSqlStore(base_sqlstore)

        # Create test table
        table_name = f"test_json_comparison_{backend_name}"
        await authorized_store.create_table(
            table=table_name,
            schema={
                "id": ColumnType.STRING,
                "data": ColumnType.STRING,
            },
        )

        try:
            # Test with no authenticated user (should handle JSON null comparison)
            mock_get_authenticated_user.return_value = None

            # Insert some test data
            await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})

            # Test fetching with no user - should not error on JSON comparison
            result = await authorized_store.fetch_all(table_name, policy=default_policy())
            assert len(result.data) == 1
            assert result.data[0]["id"] == "1"
            assert result.data[0]["access_attributes"] is None

            # Test with authenticated user
            test_user = User("test-user", {"roles": ["admin"]})
            mock_get_authenticated_user.return_value = test_user

            # Insert data with user attributes
            await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})

            # Fetch all - admin should see both
            result = await authorized_store.fetch_all(table_name, policy=default_policy())
            assert len(result.data) == 2

            # Test with non-admin user
            regular_user = User("regular-user", {"roles": ["user"]})
            mock_get_authenticated_user.return_value = regular_user

            # Should only see public record
            result = await authorized_store.fetch_all(table_name, policy=default_policy())
            assert len(result.data) == 1
            assert result.data[0]["id"] == "1"

            # Test the category missing branch: user with multiple attributes
            multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
            mock_get_authenticated_user.return_value = multi_user

            # Insert record with multi-user (has both roles and teams)
            await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})

            # Test different user types to create records with different attribute patterns
            # Record with only roles (teams category will be missing)
            roles_only_user = User("roles-user", {"roles": ["admin"]})
            mock_get_authenticated_user.return_value = roles_only_user
            await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})

            # Record with only teams (roles category will be missing)
            teams_only_user = User("teams-user", {"teams": ["dev"]})
            mock_get_authenticated_user.return_value = teams_only_user
            await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})

            # Record with different roles/teams (shouldn't match our test user)
            different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
            mock_get_authenticated_user.return_value = different_user
            await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})

            # Now test with the multi-user who has both roles=admin and teams=dev
            mock_get_authenticated_user.return_value = multi_user
            result = await authorized_store.fetch_all(table_name, policy=default_policy())

            # Should see:
            # - public record (1) - no access_attributes
            # - admin record (2) - user matches roles=admin, teams missing (allowed)
            # - multi_user record (3) - user matches both roles=admin and teams=dev
            # - roles_only record (4) - user matches roles=admin, teams missing (allowed)
            # - teams_only record (5) - user matches teams=dev, roles missing (allowed)
            # Should NOT see:
            # - different_user record (6) - user doesn't match roles=user or teams=qa
            expected_ids = {"1", "2", "3", "4", "5"}
            actual_ids = {record["id"] for record in result.data}
            assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"

            # Verify the category missing logic specifically:
            # records 4 and 5 test the branch where one attribute category is missing
            category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
            assert category_test_ids == {"4", "5"}, (
                f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
            )

        finally:
            # Clean up records
            for record_id in ["1", "2", "3", "4", "5", "6"]:
                try:
                    await base_sqlstore.delete(table_name, {"id": record_id})
                except Exception:
                    pass

    finally:
        # Clean up temporary SQLite database file if needed
        if cleanup_path:
            try:
                os.unlink(cleanup_path)
            except OSError:
                pass


@@ -104,19 +104,17 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
     # Test scenarios with different access control patterns
     test_scenarios = [
-        # Scenario 1: Public record (no access control)
+        # Scenario 1: Public record (no access control - represents None user insert)
         {"id": "1", "name": "public", "access_attributes": None},
-        # Scenario 2: Empty access control (should be treated as public)
-        {"id": "2", "name": "empty", "access_attributes": {}},
-        # Scenario 3: Record with roles requirement
-        {"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
-        # Scenario 4: Record with multiple attribute categories
-        {"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
-        # Scenario 5: Record with teams only (missing roles category)
-        {"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
-        # Scenario 6: Record with roles and projects
+        # Scenario 2: Record with roles requirement
+        {"id": "2", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
+        # Scenario 3: Record with multiple attribute categories
+        {"id": "3", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
+        # Scenario 4: Record with teams only (missing roles category)
+        {"id": "4", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
+        # Scenario 5: Record with roles and projects
         {
-            "id": "6",
+            "id": "5",
             "name": "admin-project-x",
             "access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
         },