chore(package): migrate to src/ layout (#3920)

Migrates package structure to src/ layout following Python packaging
best practices.

All code moved from `llama_stack/` to `src/llama_stack/`. Public API
unchanged - imports remain `import llama_stack.*`.

Updated build configs, pre-commit hooks, scripts, and GitHub workflows
accordingly. All hooks pass, package builds cleanly.

**Developer note**: Reinstall after pulling: `pip install -e .`
This commit is contained in:
Ashwin Bharambe 2025-10-27 12:02:21 -07:00 committed by GitHub
parent 98a5047f9d
commit 471b1b248b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
791 changed files with 2983 additions and 456 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Literal, Protocol
from pydantic import BaseModel
from llama_stack.apis.common.responses import PaginatedResponse
class ColumnType(Enum):
INTEGER = "INTEGER"
STRING = "STRING"
TEXT = "TEXT"
FLOAT = "FLOAT"
BOOLEAN = "BOOLEAN"
JSON = "JSON"
DATETIME = "DATETIME"
class ColumnDefinition(BaseModel):
type: ColumnType
primary_key: bool = False
nullable: bool = True
default: Any = None
class SqlStore(Protocol):
"""
A protocol for a SQL store.
"""
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""
Create a table.
"""
pass
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""
Insert a row or batch of rows into a table.
"""
pass
async def fetch_all(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""
Fetch all rows from a table with optional cursor-based pagination.
:param table: The table name
:param where: Simple key-value WHERE conditions
:param where_sql: Raw SQL WHERE clause for complex queries
:param limit: Maximum number of records to return
:param order_by: List of (column, order) tuples for sorting
:param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
Requires order_by with exactly one column when used
:return: PaginatedResult with data and has_more flag
Note: Cursor pagination only supports single-column ordering for simplicity.
Multi-column ordering is allowed without cursor but will raise an error with cursor.
"""
pass
async def fetch_one(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""
Fetch one row from a table.
"""
pass
async def update(
self,
table: str,
data: Mapping[str, Any],
where: Mapping[str, Any],
) -> None:
"""
Update a row in a table.
"""
pass
async def delete(
self,
table: str,
where: Mapping[str, Any],
) -> None:
"""
Delete a row from a table.
"""
pass
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""
Add a column to an existing table if the column doesn't already exist.
This is useful for table migrations when adding new functionality.
If the table doesn't exist, this method should do nothing.
If the column already exists, this method should do nothing.
:param table: Table name
:param column_name: Name of the column to add
:param column_type: Type of the column to add
:param nullable: Whether the column should be nullable (default: True)
"""
pass

View file

@ -0,0 +1,303 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.conditions import ProtectedResource
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
logger = get_logger(name=__name__, category="providers::utils")
# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly
# or SQL filtering will fall back to conservative mode (safe but less performant)
#
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
# - Public records (no access_attributes) are always accessible
# - Records with access_attributes require user to match ALL categories that exist in the resource
# - Missing categories in the resource are treated as "no restriction" (allow)
# - Within each category, user needs ANY matching value (OR logic)
# - Between categories, user needs ALL categories to match (AND logic)
SQL_OPTIMIZED_POLICY = [
AccessRule(
permit=Scope(actions=list(Action)),
when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
),
]
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
"""Add access control attributes to a data item."""
enhanced = dict(item)
if current_user:
enhanced["owner_principal"] = current_user.principal
enhanced["access_attributes"] = current_user.attributes
else:
enhanced["owner_principal"] = None
enhanced["access_attributes"] = None
return enhanced
class SqlRecord(ProtectedResource):
def __init__(self, record_id: str, table_name: str, owner: User):
self.type = f"sql_record::{table_name}"
self.identifier = record_id
self.owner = owner
class AuthorizedSqlStore:
"""
Authorization layer for SqlStore that provides access control functionality.
This class composes a base SqlStore and adds authorization methods that handle
access control policies, user attribute capture, and SQL filtering optimization.
"""
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
"""
Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap
:param policy: Access control policy to use for authorization
"""
self.sql_store = sql_store
self.policy = policy
self._detect_database_type()
self._validate_sql_optimized_policy()
def _detect_database_type(self) -> None:
"""Detect the database type from the underlying SQL store."""
if not hasattr(self.sql_store, "config"):
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
self.database_type = self.sql_store.config.type.value
if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _validate_sql_optimized_policy(self) -> None:
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
This ensures that if default_policy() changes, we detect the mismatch and
can update our SQL filtering logic accordingly.
"""
actual_default = default_policy()
if SQL_OPTIMIZED_POLICY != actual_default:
logger.warning(
f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
f"SQL filtering will use conservative mode. "
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
)
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support."""
enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON
if "owner_principal" not in enhanced_schema:
enhanced_schema["owner_principal"] = ColumnType.STRING
await self.sql_store.create_table(table, enhanced_schema)
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""Insert a row or batch of rows with automatic access control attribute capture."""
current_user = get_authenticated_user()
enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
if isinstance(data, Mapping):
enhanced_data = _enhance_item_with_access_control(data, current_user)
else:
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
await self.sql_store.insert(table, enhanced_data)
async def fetch_all(
self,
table: str,
where: Mapping[str, Any] | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(self.policy)
rows = await self.sql_store.fetch_all(
table=table,
where=where,
where_sql=access_where,
limit=limit,
order_by=order_by,
cursor=cursor,
)
current_user = get_authenticated_user()
filtered_rows = []
for row in rows.data:
stored_access_attrs = row.get("access_attributes")
stored_owner_principal = row.get("owner_principal") or ""
record_id = row.get("id", "unknown")
sql_record = SqlRecord(
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
return PaginatedResponse(
data=filtered_rows,
has_more=rows.has_more,
)
async def fetch_one(
self,
table: str,
where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking."""
results = await self.fetch_all(
table=table,
where=where,
limit=1,
order_by=order_by,
)
return results.data[0] if results.data else None
async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
"""Update rows with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user()
if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
await self.sql_store.update(table, enhanced_data, where)
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
"""Delete rows with automatic access control filtering."""
await self.sql_store.delete(table, where)
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
"""Build SQL WHERE clause for access control filtering.
Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access.
"""
current_user = get_authenticated_user()
if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause(current_user)
else:
return self._build_conservative_where_clause()
def _json_extract(self, column: str, path: str) -> str:
"""Extract JSON value (keeping JSON type).
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value
"""
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
return f"{column}->'{path}'"
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _json_extract_text(self, column: str, path: str) -> str:
"""Extract JSON value as text.
Args:
column: The JSON column name
path: The JSON path (e.g., 'roles', 'teams')
Returns:
SQL expression to extract JSON value as text
"""
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
return f"{column}->>'{path}'"
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
return f"JSON_EXTRACT({column}, '$.{path}')"
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
conditions = ["owner_principal = ''"]
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
# Postgres stores JSON null as 'null'
conditions.append("access_attributes::text = 'null'")
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
conditions.append("access_attributes = 'null'")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
"""Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource.
"""
base_conditions = self._get_public_access_conditions()
user_attr_conditions = []
if current_user and current_user.attributes:
for attr_key, user_values in current_user.attributes.items():
if user_values:
value_conditions = []
for value in user_values:
# Check if JSON array contains the value
escaped_value = value.replace("'", "''")
json_text = self._json_extract_text("access_attributes", attr_key)
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
if value_conditions:
# Check if the category is missing (NULL)
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
user_matches_category = f"({' OR '.join(value_conditions)})"
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
if user_attr_conditions:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies.
Only filters records we're 100% certain would be denied by any reasonable policy.
"""
current_user = get_authenticated_user()
if not current_user:
# Only allow public records
base_conditions = self._get_public_access_conditions()
return f"({' OR '.join(base_conditions)})"
return "1=1"

View file

@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
Float,
Integer,
MetaData,
String,
Table,
Text,
inspect,
select,
text,
)
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.elements import ColumnElement
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
logger = get_logger(name=__name__, category="providers::utils")
TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer,
ColumnType.STRING: String,
ColumnType.FLOAT: Float,
ColumnType.BOOLEAN: Boolean,
ColumnType.DATETIME: DateTime,
ColumnType.TEXT: Text,
ColumnType.JSON: JSON,
}
def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
"""Return a SQLAlchemy expression for a where condition.
`value` may be a simple scalar (equality) or a mapping like {">": 123}.
The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
"""
if isinstance(value, Mapping):
if len(value) != 1:
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
op, operand = next(iter(value.items()))
if op == "==" or op == "=":
return column == operand
if op == ">":
return column > operand
if op == "<":
return column < operand
if op == ">=":
return column >= operand
if op == "<=":
return column <= operand
raise ValueError(f"Unsupported operator '{op}' in where mapping")
return column == value
class SqlAlchemySqlStoreImpl(SqlStore):
def __init__(self, config: SqlAlchemySqlStoreConfig):
self.config = config
self.async_session = async_sessionmaker(self.create_engine())
self.metadata = MetaData()
def create_engine(self) -> AsyncEngine:
return create_async_engine(self.config.engine_str, pool_pre_ping=True)
async def create_table(
self,
table: str,
schema: Mapping[str, ColumnType | ColumnDefinition],
) -> None:
if not schema:
raise ValueError(f"No columns defined for table '{table}'.")
sqlalchemy_columns: list[Column] = []
for col_name, col_props in schema.items():
col_type = None
is_primary_key = False
is_nullable = True
if isinstance(col_props, ColumnType):
col_type = col_props
elif isinstance(col_props, ColumnDefinition):
col_type = col_props.type
is_primary_key = col_props.primary_key
is_nullable = col_props.nullable
sqlalchemy_type = TYPE_MAPPING.get(col_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.")
sqlalchemy_columns.append(
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
)
if table not in self.metadata.tables:
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
else:
sqlalchemy_table = self.metadata.tables[table]
engine = self.create_engine()
async with engine.begin() as conn:
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
async with self.async_session() as session:
await session.execute(self.metadata.tables[table].insert(), data)
await session.commit()
async def fetch_all(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
async with self.async_session() as session:
table_obj = self.metadata.tables[table]
query = select(table_obj)
if where:
for key, value in where.items():
query = query.where(_build_where_expr(table_obj.c[key], value))
if where_sql:
query = query.where(text(where_sql))
# Handle cursor-based pagination
if cursor:
# Validate cursor tuple format
if not isinstance(cursor, tuple) or len(cursor) != 2:
raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")
# Require order_by for cursor pagination
if not order_by:
raise ValueError("order_by is required when using cursor pagination")
# Only support single-column ordering for cursor pagination
if len(order_by) != 1:
raise ValueError(
f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
)
cursor_key_column, cursor_id = cursor
order_column, order_direction = order_by[0]
# Verify cursor_key_column exists
if cursor_key_column not in table_obj.c:
raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")
# Get cursor value for the order column
cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
cursor_result = await session.execute(cursor_query)
cursor_row = cursor_result.fetchone()
if not cursor_row:
raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")
cursor_value = cursor_row[0]
# Apply cursor condition based on sort direction
if order_direction == "desc":
query = query.where(table_obj.c[order_column] < cursor_value)
else:
query = query.where(table_obj.c[order_column] > cursor_value)
# Apply ordering
if order_by:
if not isinstance(order_by, list):
raise ValueError(
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
)
for order in order_by:
if not isinstance(order, tuple):
raise ValueError(
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
)
name, order_type = order
if name not in table_obj.c:
raise ValueError(f"Column '{name}' not found in table '{table}'")
if order_type == "asc":
query = query.order_by(table_obj.c[name].asc())
elif order_type == "desc":
query = query.order_by(table_obj.c[name].desc())
else:
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
# Fetch limit + 1 to determine has_more
fetch_limit = limit
if limit:
fetch_limit = limit + 1
if fetch_limit:
query = query.limit(fetch_limit)
result = await session.execute(query)
if result.rowcount == 0:
rows = []
else:
rows = [dict(row._mapping) for row in result]
# Always return pagination result
has_more = False
if limit and len(rows) > limit:
has_more = True
rows = rows[:limit]
return PaginatedResponse(data=rows, has_more=has_more)
async def fetch_one(
self,
table: str,
where: Mapping[str, Any] | None = None,
where_sql: str | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
if not result.data:
return None
return result.data[0]
async def update(
self,
table: str,
data: Mapping[str, Any],
where: Mapping[str, Any],
) -> None:
if not where:
raise ValueError("where is required for update")
async with self.async_session() as session:
stmt = self.metadata.tables[table].update()
for key, value in where.items():
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
await session.execute(stmt, data)
await session.commit()
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
if not where:
raise ValueError("where is required for delete")
async with self.async_session() as session:
stmt = self.metadata.tables[table].delete()
for key, value in where.items():
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
await session.execute(stmt)
await session.commit()
async def add_column_if_not_exists(
self,
table: str,
column_name: str,
column_type: ColumnType,
nullable: bool = True,
) -> None:
"""Add a column to an existing table if the column doesn't already exist."""
engine = self.create_engine()
try:
async with engine.begin() as conn:
def check_column_exists(sync_conn):
inspector = inspect(sync_conn)
table_names = inspector.get_table_names()
if table not in table_names:
return False, False # table doesn't exist, column doesn't exist
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
return True, column_name in column_names # table exists, column exists or not
table_exists, column_exists = await conn.run_sync(check_column_exists)
if not table_exists or column_exists:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
await conn.execute(add_column_sql)
except Exception as e:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass

View file

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, cast
from pydantic import Field
from llama_stack.core.storage.datatypes import (
PostgresSqlStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageBackendConfig,
StorageBackendType,
)
from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
SqlStoreConfig = Annotated[
SqliteSqlStoreConfig | PostgresSqlStoreConfig,
Field(discriminator="type"),
]
def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
"""Get pip packages for SQL store config, handling both dict and object cases."""
if isinstance(store_config, dict):
store_type = store_config.get("type")
if store_type == StorageBackendType.SQL_SQLITE.value:
return SqliteSqlStoreConfig.pip_packages()
elif store_type == StorageBackendType.SQL_POSTGRES.value:
return PostgresSqlStoreConfig.pip_packages()
else:
raise ValueError(f"Unknown SQL store type: {store_type}")
else:
return store_config.pip_packages()
def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
backend_name = reference.backend
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
return SqlAlchemySqlStoreImpl(config)
else:
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available SQL store backends for reference resolution."""
global _SQLSTORE_BACKENDS
_SQLSTORE_BACKENDS.clear()
for name, cfg in backends.items():
_SQLSTORE_BACKENDS[name] = cfg