fix: authorized sql store with postgres (#2641)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 4s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 5s
Integration Tests / test-matrix (library, 3.13, post_training) (push) Failing after 4s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 13s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 16s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (library, 3.13, datasets) (push) Failing after 15s
Integration Tests / test-matrix (server, 3.12, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.13, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, tool_runtime) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.13, scoring) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.12, agents) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.13, inspect) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.13, datasets) (push) Failing after 8s
Integration Tests / test-matrix (server, 3.13, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, inference) (push) Failing after 6s
Integration Tests / test-matrix (server, 3.13, providers) (push) Failing after 10s
Integration Tests / test-matrix (server, 3.13, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.13, inference) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.12, inspect) (push) Failing after 12s
Integration Tests / test-matrix (server, 3.12, scoring) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.13, agents) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 8s
Integration Tests / test-matrix (server, 3.13, post_training) (push) Failing after 11s
Integration Tests / test-matrix (server, 3.13, vector_io) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.13, agents) (push) Failing after 13s
Integration Tests / test-matrix (server, 3.12, vector_io) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.12, post_training) (push) Failing after 14s
Integration Tests / test-matrix (server, 3.13, tool_runtime) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 28s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 5s
Test Llama Stack Build / generate-matrix (push) Successful in 5s
Python Package Build Test / build (3.12) (push) Failing after 1s
Test External Providers / test-external-providers (venv) (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test Llama Stack Build / build (push) Failing after 4s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test Llama Stack Build / build-single-provider (push) Failing after 44s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 41s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 43s
Pre-commit / pre-commit (push) Successful in 1m34s

# What does this PR do?
postgres has different json extract syntax from sqlite

## Test Plan
added integration test
This commit is contained in:
ehhuang 2025-07-07 19:36:34 -07:00 committed by GitHub
parent 5bb3817c49
commit e9926564bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 337 additions and 27 deletions

View file

@ -15,6 +15,7 @@ from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="authorized_sqlstore")
@ -71,9 +72,18 @@ class AuthorizedSqlStore:
:param sql_store: Base SqlStore implementation to wrap
"""
self.sql_store = sql_store
self._detect_database_type()
self._validate_sql_optimized_policy()
def _detect_database_type(self) -> None:
"""Detect the database type from the underlying SQL store."""
if not hasattr(self.sql_store, "config"):
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
self.database_type = self.sql_store.config.type
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _validate_sql_optimized_policy(self) -> None:
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
@ -181,6 +191,50 @@ class AuthorizedSqlStore:
else:
return self._build_conservative_where_clause()
def _json_extract(self, column: str, path: str) -> str:
"""Extract JSON value (keeping JSON type).
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value
"""
if self.database_type == SqlStoreType.postgres:
return f"{column}->'{path}'"
elif self.database_type == SqlStoreType.sqlite:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _json_extract_text(self, column: str, path: str) -> str:
"""Extract JSON value as text.
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value as text
"""
if self.database_type == SqlStoreType.postgres:
return f"{column}->>'{path}'"
elif self.database_type == SqlStoreType.sqlite:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access."""
if self.database_type == SqlStoreType.postgres:
# Postgres stores JSON null as 'null'
return ["access_attributes::text = 'null'"]
elif self.database_type == SqlStoreType.sqlite:
return ["access_attributes = 'null'"]
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _build_default_policy_where_clause(self) -> str:
"""Build SQL WHERE clause for the default policy.
@ -189,30 +243,33 @@ class AuthorizedSqlStore:
"""
current_user = get_authenticated_user()
base_conditions = self._get_public_access_conditions()
if not current_user or not current_user.attributes:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
# Only allow public records
return f"({' OR '.join(base_conditions)})"
else:
base_conditions = ["access_attributes IS NULL", "access_attributes = 'null'", "access_attributes = '{}'"]
user_attr_conditions = []
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
value_conditions.append(f"JSON_EXTRACT(access_attributes, '$.{attr_key}') LIKE '%\"{value}\"%'")
# Check if JSON array contains the value
escaped_value = value.replace("'", "''")
json_text = self._json_extract_text("access_attributes", attr_key)
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
if value_conditions:
category_missing = f"JSON_EXTRACT(access_attributes, '$.{attr_key}') IS NULL"
# Check if the category is missing (NULL)
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
user_matches_category = f"({' OR '.join(value_conditions)})"
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
if user_attr_conditions:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})"
else:
return f"({' OR '.join(base_conditions)})"
return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies.
@ -222,5 +279,8 @@ class AuthorizedSqlStore:
current_user = get_authenticated_user()
if not current_user:
return "(access_attributes IS NULL OR access_attributes = 'null' OR access_attributes = '{}')"
# Only allow public records
base_conditions = self._get_public_access_conditions()
return f"({' OR '.join(base_conditions)})"
return "1=1"

View file

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import abstractmethod
from enum import Enum
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Literal
@ -19,7 +18,7 @@ from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
class SqlStoreType(Enum):
class SqlStoreType(StrEnum):
sqlite = "sqlite"
postgres = "postgres"
@ -36,7 +35,7 @@ class SqlAlchemySqlStoreConfig(BaseModel):
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["sqlite"] = SqlStoreType.sqlite.value
type: Literal[SqlStoreType.sqlite] = SqlStoreType.sqlite
db_path: str = Field(
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
@ -59,7 +58,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["postgres"] = SqlStoreType.postgres.value
type: Literal[SqlStoreType.postgres] = SqlStoreType.postgres
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
@ -107,7 +106,7 @@ def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
if config.type in [SqlStoreType.sqlite, SqlStoreType.postgres]:
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
impl = SqlAlchemySqlStoreImpl(config)