diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml new file mode 100644 index 000000000..aeeecf395 --- /dev/null +++ b/.github/workflows/integration-sql-store-tests.yml @@ -0,0 +1,70 @@ +name: SqlStore Integration Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + paths: + - 'llama_stack/providers/utils/sqlstore/**' + - 'tests/integration/sqlstore/**' + - 'uv.lock' + - 'pyproject.toml' + - 'requirements.txt' + - '.github/workflows/integration-sql-store-tests.yml' # This workflow + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test-postgres: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12", "3.13"] + fail-fast: false + + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: llamastack + POSTGRES_PASSWORD: llamastack + POSTGRES_DB: llamastack + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Install dependencies + uses: ./.github/actions/setup-runner + with: + python-version: ${{ matrix.python-version }} + + - name: Run SqlStore Integration Tests + env: + ENABLE_POSTGRES_TESTS: "true" + POSTGRES_HOST: localhost + POSTGRES_PORT: 5432 + POSTGRES_DB: llamastack + POSTGRES_USER: llamastack + POSTGRES_PASSWORD: llamastack + run: | + uv run pytest -sv tests/integration/providers/utils/sqlstore/ + + - name: Upload test logs + if: ${{ always() }} + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: postgres-test-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }} + path: | + *.log + retention-days: 1 diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index 65401382f..5dff7f122 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore +from .sqlstore import SqlStoreType logger = get_logger(name=__name__, category="authorized_sqlstore") @@ -71,9 +72,18 @@ class AuthorizedSqlStore: :param sql_store: Base SqlStore implementation to wrap """ self.sql_store = sql_store - + self._detect_database_type() self._validate_sql_optimized_policy() + def _detect_database_type(self) -> None: + """Detect the database type from the underlying SQL store.""" + if not hasattr(self.sql_store, "config"): + raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore") + + self.database_type = self.sql_store.config.type + if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]: + raise ValueError(f"Unsupported database type: {self.database_type}") + def _validate_sql_optimized_policy(self) -> None: """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy(). @@ -181,6 +191,50 @@ class AuthorizedSqlStore: else: return self._build_conservative_where_clause() + def _json_extract(self, column: str, path: str) -> str: + """Extract JSON value (keeping JSON type). + + Args: + column: The JSON column name + path: The JSON path (e.g., 'roles', 'teams') + + Returns: + SQL expression to extract JSON value + """ + if self.database_type == SqlStoreType.postgres: + return f"{column}->'{path}'" + elif self.database_type == SqlStoreType.sqlite: + return f"JSON_EXTRACT({column}, '$.{path}')" + else: + raise ValueError(f"Unsupported database type: {self.database_type}") + + def _json_extract_text(self, column: str, path: str) -> str: + """Extract JSON value as text. + + Args: + column: The JSON column name + path: The JSON path (e.g., 'roles', 'teams') + + Returns: + SQL expression to extract JSON value as text + """ + if self.database_type == SqlStoreType.postgres: + return f"{column}->>'{path}'" + elif self.database_type == SqlStoreType.sqlite: + return f"JSON_EXTRACT({column}, '$.{path}')" + else: + raise ValueError(f"Unsupported database type: {self.database_type}") + + def _get_public_access_conditions(self) -> list[str]: + """Get the SQL conditions for public access.""" + if self.database_type == SqlStoreType.postgres: + # Postgres stores JSON null as 'null' + return ["access_attributes::text = 'null'"] + elif self.database_type == SqlStoreType.sqlite: + return ["access_attributes = 'null'"] + else: + raise ValueError(f"Unsupported database type: {self.database_type}") + def _build_default_policy_where_clause(self) -> str: """Build SQL WHERE clause for the default policy. @@ -189,30 +243,33 @@ class AuthorizedSqlStore: """ current_user = get_authenticated_user() + base_conditions = self._get_public_access_conditions() if not current_user or not current_user.attributes: - return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')" + # Only allow public records + return f"({' OR '.join(base_conditions)})" else: - base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"] - user_attr_conditions = [] for attr_key, user_values in current_user.attributes.items(): if user_values: value_conditions = [] for value in user_values: - value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'") + # Check if JSON array contains the value + escaped_value = value.replace("'", "''") + json_text = self._json_extract_text("access_attributes", attr_key) + value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')") if value_conditions: - category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL" + # Check if the category is missing (NULL) + category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL" user_matches_category = f"({' OR '.join(value_conditions)})" user_attr_conditions.append(f"({category_missing} OR {user_matches_category})") if user_attr_conditions: all_requirements_met = f"({' AND '.join(user_attr_conditions)})" base_conditions.append(all_requirements_met) - return f"({' OR '.join(base_conditions)})" - else: - return f"({' OR '.join(base_conditions)})" + + return f"({' OR '.join(base_conditions)})" def _build_conservative_where_clause(self) -> str: """Conservative SQL filtering for custom policies. @@ -222,5 +279,8 @@ class AuthorizedSqlStore: current_user = get_authenticated_user() if not current_user: - return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')" + # Only allow public records + base_conditions = self._get_public_access_conditions() + return f"({' OR '.join(base_conditions)})" + return "1=1" diff --git a/llama_stack/providers/utils/sqlstore/sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlstore.py index 227c5abcd..9f7eefcf5 100644 --- a/llama_stack/providers/utils/sqlstore/sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -4,9 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - from abc import abstractmethod -from enum import Enum +from enum import StrEnum from pathlib import Path from typing import Annotated, Literal @@ -19,7 +18,7 @@ from .api import SqlStore sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"] -class SqlStoreType(Enum): +class SqlStoreType(StrEnum): sqlite = "sqlite" postgres = "postgres" @@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel): class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig): - type: Literal["sqlite"] = SqlStoreType.sqlite.value + type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite db_path: str = Field( default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db", @@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig): class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig): - type: Literal["postgres"] = SqlStoreType.postgres.value + type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres host: str = "localhost" port: int = 5432 db: str = "llamastack" @@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]: def sqlstore_impl(config: SqlStoreConfig) -> SqlStore: - if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]: + if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]: from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl impl = SqlAlchemySqlStoreImpl(config) diff --git a/tests/integration/providers/utils/__init__.py b/tests/integration/providers/utils/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/integration/providers/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/integration/providers/utils/sqlstore/__init__.py b/tests/integration/providers/utils/sqlstore/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/integration/providers/utils/sqlstore/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py new file mode 100644 index 000000000..93b4d8905 --- /dev/null +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import tempfile +from unittest.mock import patch + +import pytest + +from llama_stack.distribution.access_control.access_control import default_policy +from llama_stack.distribution.datatypes import User +from llama_stack.providers.utils.sqlstore.api import ColumnType +from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl +from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig + + +def get_postgres_config(): + """Get PostgreSQL configuration if tests are enabled.""" + return PostgresSqlStoreConfig( + host=os.environ.get("POSTGRES_HOST", "localhost"), + port=int(os.environ.get("POSTGRES_PORT", "5432")), + db=os.environ.get("POSTGRES_DB", "llamastack"), + user=os.environ.get("POSTGRES_USER", "llamastack"), + password=os.environ.get("POSTGRES_PASSWORD", "llamastack"), + ) + + +def get_sqlite_config(): + """Get SQLite configuration with temporary database.""" + tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + tmp_file.close() + return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "backend_config", + [ + pytest.param( + ("postgres", get_postgres_config), + marks=pytest.mark.skipif( + not os.environ.get("ENABLE_POSTGRES_TESTS"), + reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable", + ), + id="postgres", + ), + pytest.param(("sqlite", get_sqlite_config), id="sqlite"), + ], +) +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +async def test_json_comparison(mock_get_authenticated_user, backend_config): + """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite""" + backend_name, config_func = backend_config + + # Handle different config types + if backend_name == "postgres": + config = config_func() + cleanup_path = None + else: # sqlite + config, cleanup_path = config_func() + + try: + base_sqlstore = SqlAlchemySqlStoreImpl(config) + authorized_store = AuthorizedSqlStore(base_sqlstore) + + # Create test table + table_name = f"test_json_comparison_{backend_name}" + await authorized_store.create_table( + table=table_name, + schema={ + "id": ColumnType.STRING, + "data": ColumnType.STRING, + }, + ) + + try: + # Test with no authenticated user (should handle JSON null comparison) + mock_get_authenticated_user.return_value = None + + # Insert some test data + await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) + + # Test fetching with no user - should not error on JSON comparison + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 1 + assert result.data[0]["id"] == "1" + assert result.data[0]["access_attributes"] is None + + # Test with authenticated user + test_user = User("test-user", {"roles": ["admin"]}) + mock_get_authenticated_user.return_value = test_user + + # Insert data with user attributes + await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) + + # Fetch all - admin should see both + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 2 + + # Test with non-admin user + regular_user = User("regular-user", {"roles": ["user"]}) + mock_get_authenticated_user.return_value = regular_user + + # Should only see public record + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + assert len(result.data) == 1 + assert result.data[0]["id"] == "1" + + # Test the category missing branch: user with multiple attributes + multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]}) + mock_get_authenticated_user.return_value = multi_user + + # Insert record with multi-user (has both roles and teams) + await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"}) + + # Test different user types to create records with different attribute patterns + # Record with only roles (teams category will be missing) + roles_only_user = User("roles-user", {"roles": ["admin"]}) + mock_get_authenticated_user.return_value = roles_only_user + await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"}) + + # Record with only teams (roles category will be missing) + teams_only_user = User("teams-user", {"teams": ["dev"]}) + mock_get_authenticated_user.return_value = teams_only_user + await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"}) + + # Record with different roles/teams (shouldn't match our test user) + different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]}) + mock_get_authenticated_user.return_value = different_user + await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"}) + + # Now test with the multi-user who has both roles=admin and teams=dev + mock_get_authenticated_user.return_value = multi_user + result = await authorized_store.fetch_all(table_name, policy=default_policy()) + + # Should see: + # - public record (1) - no access_attributes + # - admin record (2) - user matches roles=admin, teams missing (allowed) + # - multi_user record (3) - user matches both roles=admin and teams=dev + # - roles_only record (4) - user matches roles=admin, teams missing (allowed) + # - teams_only record (5) - user matches teams=dev, roles missing (allowed) + # Should NOT see: + # - different_user record (6) - user doesn't match roles=user or teams=qa + expected_ids = {"1", "2", "3", "4", "5"} + actual_ids = {record["id"] for record in result.data} + assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}" + + # Verify the category missing logic specifically + # Records 4 and 5 test the "category missing" branch where one attribute category is missing + category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]} + assert category_test_ids == {"4", "5"}, ( + f"Category missing logic failed: expected 4,5 but got {category_test_ids}" + ) + + finally: + # Clean up records + for record_id in ["1", "2", "3", "4", "5", "6"]: + try: + await base_sqlstore.delete(table_name, {"id": record_id}) + except Exception: + pass + + finally: + # Clean up temporary SQLite database file if needed + if cleanup_path: + try: + os.unlink(cleanup_path) + except OSError: + pass diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index b457176a7..1624c0ba7 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -104,19 +104,17 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): # Test scenarios with different access control patterns test_scenarios = [ - # Scenario 1: Public record (no access control) + # Scenario 1: Public record (no access control - represents None user insert) {"id": "1", "name": "public", "access_attributes": None}, - # Scenario 2: Empty access control (should be treated as public) - {"id": "2", "name": "empty", "access_attributes": {}}, - # Scenario 3: Record with roles requirement - {"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}}, - # Scenario 4: Record with multiple attribute categories - {"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}}, - # Scenario 5: Record with teams only (missing roles category) - {"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}}, - # Scenario 6: Record with roles and projects + # Scenario 2: Record with roles requirement + {"id": "2", "name": "admin-only", "access_attributes": {"roles": ["admin"]}}, + # Scenario 3: Record with multiple attribute categories + {"id": "3", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}}, + # Scenario 4: Record with teams only (missing roles category) + {"id": "4", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}}, + # Scenario 5: Record with roles and projects { - "id": "6", + "id": "5", "name": "admin-project-x", "access_attributes": {"roles": ["admin"], "projects": ["project-x"]}, },