mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-14 17:16:09 +00:00
fix: authorized sql store with postgres (#2641)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 4s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 5s
Integration Tests / test-matrix (library, 3.13, post_training) (push) Failing after 4s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 13s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 16s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (library, 3.13, datasets) (push) Failing after 15s
Integration Tests / test-matrix (server, 3.12, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.13, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, tool_runtime) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.13, scoring) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.12, agents) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.13, inspect) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.13, datasets) (push) Failing after 8s
Integration Tests / test-matrix (server, 3.13, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, inference) (push) Failing after 6s
Integration Tests / test-matrix (server, 3.13, providers) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.13, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, inference) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.12, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, scoring) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.13, agents) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 8s
Integration Tests / test-matrix (server, 3.13, post_training) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.13, vector_io) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, agents) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.12, vector_io) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, post_training) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.13, tool_runtime) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 28s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 5s
Test Llama Stack Build / generate-matrix (push) Successful in 5s
Python Package Build Test / build (3.12) (push) Failing after 1s
Test External Providers / test-external-providers (venv) (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test Llama Stack Build / build (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test Llama Stack Build / build-single-provider (push) Failing after 44s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 41s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 43s
Pre-commit / pre-commit (push) Successful in 1m34s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 4s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 5s
Integration Tests / test-matrix (library, 3.13, post_training) (push) Failing after 4s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 13s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 16s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (library, 3.13, datasets) (push) Failing after 15s
Integration Tests / test-matrix (server, 3.12, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.13, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, tool_runtime) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.13, scoring) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.12, agents) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.13, inspect) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.13, datasets) (push) Failing after 8s
Integration Tests / test-matrix (server, 3.13, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, inference) (push) Failing after 6s
Integration Tests / test-matrix (server, 3.13, providers) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.13, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, inference) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.12, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, scoring) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.13, agents) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 8s
Integration Tests / test-matrix (server, 3.13, post_training) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.13, vector_io) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, agents) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.12, vector_io) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, post_training) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.13, tool_runtime) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 28s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 5s
Test Llama Stack Build / generate-matrix (push) Successful in 5s
Python Package Build Test / build (3.12) (push) Failing after 1s
Test External Providers / test-external-providers (venv) (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test Llama Stack Build / build (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test Llama Stack Build / build-single-provider (push) Failing after 44s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 41s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 43s
Pre-commit / pre-commit (push) Successful in 1m34s
# What does this PR do? postgres has different json extract syntax from sqlite ## Test Plan added integration test
This commit is contained in:
parent
5bb3817c49
commit
e9926564bd
7 changed files with 337 additions and 27 deletions
70
.github/workflows/integration-sql-store-tests.yml
vendored
Normal file
70
.github/workflows/integration-sql-store-tests.yml
vendored
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
name: SqlStore Integration Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
paths:
|
||||||
|
- 'llama_stack/providers/utils/sqlstore/**'
|
||||||
|
- 'tests/integration/sqlstore/**'
|
||||||
|
- 'uv.lock'
|
||||||
|
- 'pyproject.toml'
|
||||||
|
- 'requirements.txt'
|
||||||
|
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test-postgres:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.12", "3.13"]
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:15
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: llamastack
|
||||||
|
POSTGRES_PASSWORD: llamastack
|
||||||
|
POSTGRES_DB: llamastack
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: >-
|
||||||
|
--health-cmd pg_isready
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
uses: ./.github/actions/setup-runner
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Run SqlStore Integration Tests
|
||||||
|
env:
|
||||||
|
ENABLE_POSTGRES_TESTS: "true"
|
||||||
|
POSTGRES_HOST: localhost
|
||||||
|
POSTGRES_PORT: 5432
|
||||||
|
POSTGRES_DB: llamastack
|
||||||
|
POSTGRES_USER: llamastack
|
||||||
|
POSTGRES_PASSWORD: llamastack
|
||||||
|
run: |
|
||||||
|
uv run pytest -sv tests/integration/providers/utils/sqlstore/
|
||||||
|
|
||||||
|
- name: Upload test logs
|
||||||
|
if: ${{ always() }}
|
||||||
|
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||||
|
with:
|
||||||
|
name: postgres-test-logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.python-version }}
|
||||||
|
path: |
|
||||||
|
*.log
|
||||||
|
retention-days: 1
|
|
@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||||
|
from .sqlstore import SqlStoreType
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
||||||
|
|
||||||
|
@ -71,9 +72,18 @@ class AuthorizedSqlStore:
|
||||||
:param sql_store: Base SqlStore implementation to wrap
|
:param sql_store: Base SqlStore implementation to wrap
|
||||||
"""
|
"""
|
||||||
self.sql_store = sql_store
|
self.sql_store = sql_store
|
||||||
|
self._detect_database_type()
|
||||||
self._validate_sql_optimized_policy()
|
self._validate_sql_optimized_policy()
|
||||||
|
|
||||||
|
def _detect_database_type(self) -> None:
|
||||||
|
"""Detect the database type from the underlying SQL store."""
|
||||||
|
if not hasattr(self.sql_store, "config"):
|
||||||
|
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
|
||||||
|
|
||||||
|
self.database_type = self.sql_store.config.type
|
||||||
|
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
|
||||||
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||||
|
|
||||||
def _validate_sql_optimized_policy(self) -> None:
|
def _validate_sql_optimized_policy(self) -> None:
|
||||||
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
||||||
|
|
||||||
|
@ -181,6 +191,50 @@ class AuthorizedSqlStore:
|
||||||
else:
|
else:
|
||||||
return self._build_conservative_where_clause()
|
return self._build_conservative_where_clause()
|
||||||
|
|
||||||
|
def _json_extract(self, column: str, path: str) -> str:
|
||||||
|
"""Extract JSON value (keeping JSON type).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: The JSON column name
|
||||||
|
path: The JSON path (e.g., 'roles', 'teams')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQL expression to extract JSON value
|
||||||
|
"""
|
||||||
|
if self.database_type == SqlStoreType.postgres:
|
||||||
|
return f"{column}->'{path}'"
|
||||||
|
elif self.database_type == SqlStoreType.sqlite:
|
||||||
|
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||||
|
|
||||||
|
def _json_extract_text(self, column: str, path: str) -> str:
|
||||||
|
"""Extract JSON value as text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column: The JSON column name
|
||||||
|
path: The JSON path (e.g., 'roles', 'teams')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQL expression to extract JSON value as text
|
||||||
|
"""
|
||||||
|
if self.database_type == SqlStoreType.postgres:
|
||||||
|
return f"{column}->>'{path}'"
|
||||||
|
elif self.database_type == SqlStoreType.sqlite:
|
||||||
|
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||||
|
|
||||||
|
def _get_public_access_conditions(self) -> list[str]:
|
||||||
|
"""Get the SQL conditions for public access."""
|
||||||
|
if self.database_type == SqlStoreType.postgres:
|
||||||
|
# Postgres stores JSON null as 'null'
|
||||||
|
return ["access_attributes::text = 'null'"]
|
||||||
|
elif self.database_type == SqlStoreType.sqlite:
|
||||||
|
return ["access_attributes = 'null'"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||||
|
|
||||||
def _build_default_policy_where_clause(self) -> str:
|
def _build_default_policy_where_clause(self) -> str:
|
||||||
"""Build SQL WHERE clause for the default policy.
|
"""Build SQL WHERE clause for the default policy.
|
||||||
|
|
||||||
|
@ -189,29 +243,32 @@ class AuthorizedSqlStore:
|
||||||
"""
|
"""
|
||||||
current_user = get_authenticated_user()
|
current_user = get_authenticated_user()
|
||||||
|
|
||||||
|
base_conditions = self._get_public_access_conditions()
|
||||||
if not current_user or not current_user.attributes:
|
if not current_user or not current_user.attributes:
|
||||||
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
|
# Only allow public records
|
||||||
|
return f"({' OR '.join(base_conditions)})"
|
||||||
else:
|
else:
|
||||||
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
|
|
||||||
|
|
||||||
user_attr_conditions = []
|
user_attr_conditions = []
|
||||||
|
|
||||||
for attr_key, user_values in current_user.attributes.items():
|
for attr_key, user_values in current_user.attributes.items():
|
||||||
if user_values:
|
if user_values:
|
||||||
value_conditions = []
|
value_conditions = []
|
||||||
for value in user_values:
|
for value in user_values:
|
||||||
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
|
# Check if JSON array contains the value
|
||||||
|
escaped_value = value.replace("'", "''")
|
||||||
|
json_text = self._json_extract_text("access_attributes", attr_key)
|
||||||
|
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
|
||||||
|
|
||||||
if value_conditions:
|
if value_conditions:
|
||||||
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
|
# Check if the category is missing (NULL)
|
||||||
|
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
|
||||||
user_matches_category = f"({' OR '.join(value_conditions)})"
|
user_matches_category = f"({' OR '.join(value_conditions)})"
|
||||||
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
|
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
|
||||||
|
|
||||||
if user_attr_conditions:
|
if user_attr_conditions:
|
||||||
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
|
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
|
||||||
base_conditions.append(all_requirements_met)
|
base_conditions.append(all_requirements_met)
|
||||||
return f"({' OR '.join(base_conditions)})"
|
|
||||||
else:
|
|
||||||
return f"({' OR '.join(base_conditions)})"
|
return f"({' OR '.join(base_conditions)})"
|
||||||
|
|
||||||
def _build_conservative_where_clause(self) -> str:
|
def _build_conservative_where_clause(self) -> str:
|
||||||
|
@ -222,5 +279,8 @@ class AuthorizedSqlStore:
|
||||||
current_user = get_authenticated_user()
|
current_user = get_authenticated_user()
|
||||||
|
|
||||||
if not current_user:
|
if not current_user:
|
||||||
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
|
# Only allow public records
|
||||||
|
base_conditions = self._get_public_access_conditions()
|
||||||
|
return f"({' OR '.join(base_conditions)})"
|
||||||
|
|
||||||
return "1=1"
|
return "1=1"
|
||||||
|
|
|
@ -4,9 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
@ -19,7 +18,7 @@ from .api import SqlStore
|
||||||
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
||||||
|
|
||||||
|
|
||||||
class SqlStoreType(Enum):
|
class SqlStoreType(StrEnum):
|
||||||
sqlite = "sqlite"
|
sqlite = "sqlite"
|
||||||
postgres = "postgres"
|
postgres = "postgres"
|
||||||
|
|
||||||
|
@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
|
||||||
db_path: str = Field(
|
db_path: str = Field(
|
||||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||||
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
||||||
|
@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
|
|
||||||
|
|
||||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||||
type: Literal["postgres"] = SqlStoreType.postgres.value
|
type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int = 5432
|
port: int = 5432
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
|
@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
|
||||||
|
|
||||||
|
|
||||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||||
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
|
if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
|
||||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||||
|
|
||||||
impl = SqlAlchemySqlStoreImpl(config)
|
impl = SqlAlchemySqlStoreImpl(config)
|
||||||
|
|
5
tests/integration/providers/utils/__init__.py
Normal file
5
tests/integration/providers/utils/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
5
tests/integration/providers/utils/sqlstore/__init__.py
Normal file
5
tests/integration/providers/utils/sqlstore/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
|
@ -0,0 +1,173 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.distribution.access_control.access_control import default_policy
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||||
|
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
def get_postgres_config():
|
||||||
|
"""Get PostgreSQL configuration if tests are enabled."""
|
||||||
|
return PostgresSqlStoreConfig(
|
||||||
|
host=os.environ.get("POSTGRES_HOST", "localhost"),
|
||||||
|
port=int(os.environ.get("POSTGRES_PORT", "5432")),
|
||||||
|
db=os.environ.get("POSTGRES_DB", "llamastack"),
|
||||||
|
user=os.environ.get("POSTGRES_USER", "llamastack"),
|
||||||
|
password=os.environ.get("POSTGRES_PASSWORD", "llamastack"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sqlite_config():
|
||||||
|
"""Get SQLite configuration with temporary database."""
|
||||||
|
tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||||
|
tmp_file.close()
|
||||||
|
return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"backend_config",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
("postgres", get_postgres_config),
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
not os.environ.get("ENABLE_POSTGRES_TESTS"),
|
||||||
|
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
|
||||||
|
),
|
||||||
|
id="postgres",
|
||||||
|
),
|
||||||
|
pytest.param(("sqlite", get_sqlite_config), id="sqlite"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
|
async def test_json_comparison(mock_get_authenticated_user, backend_config):
|
||||||
|
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
|
||||||
|
backend_name, config_func = backend_config
|
||||||
|
|
||||||
|
# Handle different config types
|
||||||
|
if backend_name == "postgres":
|
||||||
|
config = config_func()
|
||||||
|
cleanup_path = None
|
||||||
|
else: # sqlite
|
||||||
|
config, cleanup_path = config_func()
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_sqlstore = SqlAlchemySqlStoreImpl(config)
|
||||||
|
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||||
|
|
||||||
|
# Create test table
|
||||||
|
table_name = f"test_json_comparison_{backend_name}"
|
||||||
|
await authorized_store.create_table(
|
||||||
|
table=table_name,
|
||||||
|
schema={
|
||||||
|
"id": ColumnType.STRING,
|
||||||
|
"data": ColumnType.STRING,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test with no authenticated user (should handle JSON null comparison)
|
||||||
|
mock_get_authenticated_user.return_value = None
|
||||||
|
|
||||||
|
# Insert some test data
|
||||||
|
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
|
||||||
|
|
||||||
|
# Test fetching with no user - should not error on JSON comparison
|
||||||
|
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0]["id"] == "1"
|
||||||
|
assert result.data[0]["access_attributes"] is None
|
||||||
|
|
||||||
|
# Test with authenticated user
|
||||||
|
test_user = User("test-user", {"roles": ["admin"]})
|
||||||
|
mock_get_authenticated_user.return_value = test_user
|
||||||
|
|
||||||
|
# Insert data with user attributes
|
||||||
|
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
|
||||||
|
|
||||||
|
# Fetch all - admin should see both
|
||||||
|
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||||
|
assert len(result.data) == 2
|
||||||
|
|
||||||
|
# Test with non-admin user
|
||||||
|
regular_user = User("regular-user", {"roles": ["user"]})
|
||||||
|
mock_get_authenticated_user.return_value = regular_user
|
||||||
|
|
||||||
|
# Should only see public record
|
||||||
|
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||||
|
assert len(result.data) == 1
|
||||||
|
assert result.data[0]["id"] == "1"
|
||||||
|
|
||||||
|
# Test the category missing branch: user with multiple attributes
|
||||||
|
multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
|
||||||
|
mock_get_authenticated_user.return_value = multi_user
|
||||||
|
|
||||||
|
# Insert record with multi-user (has both roles and teams)
|
||||||
|
await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})
|
||||||
|
|
||||||
|
# Test different user types to create records with different attribute patterns
|
||||||
|
# Record with only roles (teams category will be missing)
|
||||||
|
roles_only_user = User("roles-user", {"roles": ["admin"]})
|
||||||
|
mock_get_authenticated_user.return_value = roles_only_user
|
||||||
|
await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})
|
||||||
|
|
||||||
|
# Record with only teams (roles category will be missing)
|
||||||
|
teams_only_user = User("teams-user", {"teams": ["dev"]})
|
||||||
|
mock_get_authenticated_user.return_value = teams_only_user
|
||||||
|
await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})
|
||||||
|
|
||||||
|
# Record with different roles/teams (shouldn't match our test user)
|
||||||
|
different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
|
||||||
|
mock_get_authenticated_user.return_value = different_user
|
||||||
|
await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})
|
||||||
|
|
||||||
|
# Now test with the multi-user who has both roles=admin and teams=dev
|
||||||
|
mock_get_authenticated_user.return_value = multi_user
|
||||||
|
result = await authorized_store.fetch_all(table_name, policy=default_policy())
|
||||||
|
|
||||||
|
# Should see:
|
||||||
|
# - public record (1) - no access_attributes
|
||||||
|
# - admin record (2) - user matches roles=admin, teams missing (allowed)
|
||||||
|
# - multi_user record (3) - user matches both roles=admin and teams=dev
|
||||||
|
# - roles_only record (4) - user matches roles=admin, teams missing (allowed)
|
||||||
|
# - teams_only record (5) - user matches teams=dev, roles missing (allowed)
|
||||||
|
# Should NOT see:
|
||||||
|
# - different_user record (6) - user doesn't match roles=user or teams=qa
|
||||||
|
expected_ids = {"1", "2", "3", "4", "5"}
|
||||||
|
actual_ids = {record["id"] for record in result.data}
|
||||||
|
assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"
|
||||||
|
|
||||||
|
# Verify the category missing logic specifically
|
||||||
|
# Records 4 and 5 test the "category missing" branch where one attribute category is missing
|
||||||
|
category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
|
||||||
|
assert category_test_ids == {"4", "5"}, (
|
||||||
|
f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up records
|
||||||
|
for record_id in ["1", "2", "3", "4", "5", "6"]:
|
||||||
|
try:
|
||||||
|
await base_sqlstore.delete(table_name, {"id": record_id})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up temporary SQLite database file if needed
|
||||||
|
if cleanup_path:
|
||||||
|
try:
|
||||||
|
os.unlink(cleanup_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
|
@ -104,19 +104,17 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
|
|
||||||
# Test scenarios with different access control patterns
|
# Test scenarios with different access control patterns
|
||||||
test_scenarios = [
|
test_scenarios = [
|
||||||
# Scenario 1: Public record (no access control)
|
# Scenario 1: Public record (no access control - represents None user insert)
|
||||||
{"id": "1", "name": "public", "access_attributes": None},
|
{"id": "1", "name": "public", "access_attributes": None},
|
||||||
# Scenario 2: Empty access control (should be treated as public)
|
# Scenario 2: Record with roles requirement
|
||||||
{"id": "2", "name": "empty", "access_attributes": {}},
|
{"id": "2", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
|
||||||
# Scenario 3: Record with roles requirement
|
# Scenario 3: Record with multiple attribute categories
|
||||||
{"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
|
{"id": "3", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
||||||
# Scenario 4: Record with multiple attribute categories
|
# Scenario 4: Record with teams only (missing roles category)
|
||||||
{"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
|
{"id": "4", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
|
||||||
# Scenario 5: Record with teams only (missing roles category)
|
# Scenario 5: Record with roles and projects
|
||||||
{"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
|
|
||||||
# Scenario 6: Record with roles and projects
|
|
||||||
{
|
{
|
||||||
"id": "6",
|
"id": "5",
|
||||||
"name": "admin-project-x",
|
"name": "admin-project-x",
|
||||||
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
|
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
|
||||||
},
|
},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue