mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
refactor(storage): make { kvstore, sqlstore } as llama stack "internal" APIs (#4181)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests (Replay) / generate-matrix (push) Successful in 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 6s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test llama stack list-deps / generate-matrix (push) Successful in 3s
Python Package Build Test / build (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
Python Package Build Test / build (3.12) (push) Failing after 7s
Test llama stack list-deps / show-single-provider (push) Successful in 28s
Test llama stack list-deps / list-deps-from-config (push) Successful in 33s
Test External API and Providers / test-external (venv) (push) Failing after 33s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
Test llama stack list-deps / list-deps (push) Failing after 34s
Test Llama Stack Build / build-single-provider (push) Successful in 46s
Test Llama Stack Build / build (push) Successful in 55s
UI Tests / ui-tests (22) (push) Successful in 1m17s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 1m37s
Unit Tests / unit-tests (3.12) (push) Failing after 1m32s
Unit Tests / unit-tests (3.13) (push) Failing after 2m12s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 2m21s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2m46s
Pre-commit / pre-commit (push) Successful in 3m7s
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests (Replay) / generate-matrix (push) Successful in 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 6s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test llama stack list-deps / generate-matrix (push) Successful in 3s
Python Package Build Test / build (3.13) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
Python Package Build Test / build (3.12) (push) Failing after 7s
Test llama stack list-deps / show-single-provider (push) Successful in 28s
Test llama stack list-deps / list-deps-from-config (push) Successful in 33s
Test External API and Providers / test-external (venv) (push) Failing after 33s
Vector IO Integration Tests / test-matrix (push) Failing after 43s
Test llama stack list-deps / list-deps (push) Failing after 34s
Test Llama Stack Build / build-single-provider (push) Successful in 46s
Test Llama Stack Build / build (push) Successful in 55s
UI Tests / ui-tests (22) (push) Successful in 1m17s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 1m37s
Unit Tests / unit-tests (3.12) (push) Failing after 1m32s
Unit Tests / unit-tests (3.13) (push) Failing after 2m12s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 2m21s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2m46s
Pre-commit / pre-commit (push) Successful in 3m7s
These primitives (used both by the Stack as well as provider implementations) can be thought of fruitfully as internal-only APIs which can themselves have multiple implementations. We use the new `llama_stack_api.internal` namespace for this. In addition: the change moves kv/sql store impls, configs, and dependency helpers under `core/storage` ## Testing `pytest tests/unit/utils/test_authorized_sqlstore.py`, other existing CI
This commit is contained in:
parent
a3580e6bc0
commit
bd5ad2963e
68 changed files with 302 additions and 309 deletions
|
|
@ -11,10 +11,9 @@ from typing import Any, Literal
|
|||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule, StackRunConfig
|
||||
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack_api import (
|
||||
Conversation,
|
||||
ConversationDeletedResource,
|
||||
|
|
@ -25,6 +24,7 @@ from llama_stack_api import (
|
|||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_conversations")
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing import Any
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,9 +11,9 @@ from datetime import UTC, datetime, timedelta
|
|||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
|
||||
from llama_stack.core.storage.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
|
||||
from llama_stack_api.internal.kvstore import KVStore
|
||||
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
|
|
|||
|
|
@ -385,8 +385,8 @@ def _initialize_storage(run_config: StackRunConfig):
|
|||
else:
|
||||
raise ValueError(f"Unknown storage backend type: {type}")
|
||||
|
||||
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
register_kvstore_backends(kv_backends)
|
||||
register_sqlstore_backends(sql_backends)
|
||||
|
|
|
|||
9
src/llama_stack/core/storage/kvstore/__init__.py
Normal file
9
src/llama_stack/core/storage/kvstore/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack_api.internal.kvstore import KVStore as KVStore
|
||||
|
||||
from .kvstore import * # noqa: F401, F403
|
||||
39
src/llama_stack/core/storage/kvstore/config.py
Normal file
39
src/llama_stack/core/storage/kvstore/config.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
MongoDBKVStoreConfig,
|
||||
PostgresKVStoreConfig,
|
||||
RedisKVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
StorageBackendType,
|
||||
)
|
||||
|
||||
KVStoreConfig = Annotated[
|
||||
RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
|
||||
]
|
||||
|
||||
|
||||
def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
|
||||
"""Get pip packages for KV store config, handling both dict and object cases."""
|
||||
if isinstance(store_config, dict):
|
||||
store_type = store_config.get("type")
|
||||
if store_type == StorageBackendType.KV_SQLITE.value:
|
||||
return SqliteKVStoreConfig.pip_packages()
|
||||
elif store_type == StorageBackendType.KV_POSTGRES.value:
|
||||
return PostgresKVStoreConfig.pip_packages()
|
||||
elif store_type == StorageBackendType.KV_REDIS.value:
|
||||
return RedisKVStoreConfig.pip_packages()
|
||||
elif store_type == StorageBackendType.KV_MONGODB.value:
|
||||
return MongoDBKVStoreConfig.pip_packages()
|
||||
else:
|
||||
raise ValueError(f"Unknown KV store type: {store_type}")
|
||||
else:
|
||||
return store_config.pip_packages()
|
||||
128
src/llama_stack/core/storage/kvstore/kvstore.py
Normal file
128
src/llama_stack/core/storage/kvstore/kvstore.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig
|
||||
from llama_stack_api.internal.kvstore import KVStore
|
||||
|
||||
from .config import (
|
||||
KVStoreConfig,
|
||||
MongoDBKVStoreConfig,
|
||||
PostgresKVStoreConfig,
|
||||
RedisKVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
def kvstore_dependencies():
|
||||
"""
|
||||
Returns all possible kvstore dependencies for registry/provider specifications.
|
||||
|
||||
NOTE: For specific kvstore implementations, use config.pip_packages instead.
|
||||
This function returns the union of all dependencies for cases where the specific
|
||||
kvstore type is not known at declaration time (e.g., provider registries).
|
||||
"""
|
||||
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
|
||||
|
||||
|
||||
class InmemoryKVStoreImpl(KVStore):
|
||||
def __init__(self):
|
||||
self._store: dict[str, str] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self._store.get(key)
|
||||
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
self._store[key] = value
|
||||
|
||||
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
|
||||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
"""Get all keys in the given range."""
|
||||
return [key for key in self._store.keys() if key >= start_key and key < end_key]
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
del self._store[key]
|
||||
|
||||
|
||||
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
|
||||
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
|
||||
_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
|
||||
|
||||
def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
|
||||
"""Register the set of available KV store backends for reference resolution."""
|
||||
global _KVSTORE_BACKENDS
|
||||
global _KVSTORE_INSTANCES
|
||||
global _KVSTORE_LOCKS
|
||||
|
||||
_KVSTORE_BACKENDS.clear()
|
||||
_KVSTORE_INSTANCES.clear()
|
||||
_KVSTORE_LOCKS.clear()
|
||||
for name, cfg in backends.items():
|
||||
typed_cfg = cast(KVStoreConfig, cfg)
|
||||
_KVSTORE_BACKENDS[name] = typed_cfg
|
||||
|
||||
|
||||
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
|
||||
backend_name = reference.backend
|
||||
cache_key = (backend_name, reference.namespace)
|
||||
|
||||
existing = _KVSTORE_INSTANCES.get(cache_key)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
backend_config = _KVSTORE_BACKENDS.get(backend_name)
|
||||
if backend_config is None:
|
||||
raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
|
||||
|
||||
lock = _KVSTORE_LOCKS[cache_key]
|
||||
async with lock:
|
||||
existing = _KVSTORE_INSTANCES.get(cache_key)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
config = backend_config.model_copy()
|
||||
config.namespace = reference.namespace
|
||||
|
||||
impl: KVStore
|
||||
if isinstance(config, RedisKVStoreConfig):
|
||||
from .redis import RedisKVStoreImpl
|
||||
|
||||
impl = RedisKVStoreImpl(config)
|
||||
elif isinstance(config, SqliteKVStoreConfig):
|
||||
from .sqlite import SqliteKVStoreImpl
|
||||
|
||||
impl = SqliteKVStoreImpl(config)
|
||||
elif isinstance(config, PostgresKVStoreConfig):
|
||||
from .postgres import PostgresKVStoreImpl
|
||||
|
||||
impl = PostgresKVStoreImpl(config)
|
||||
elif isinstance(config, MongoDBKVStoreConfig):
|
||||
from .mongodb import MongoDBKVStoreImpl
|
||||
|
||||
impl = MongoDBKVStoreImpl(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown kvstore type {config.type}")
|
||||
|
||||
await impl.initialize()
|
||||
_KVSTORE_INSTANCES[cache_key] = impl
|
||||
return impl
|
||||
9
src/llama_stack/core/storage/kvstore/mongodb/__init__.py
Normal file
9
src/llama_stack/core/storage/kvstore/mongodb/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .mongodb import MongoDBKVStoreImpl
|
||||
|
||||
__all__ = ["MongoDBKVStoreImpl"]
|
||||
85
src/llama_stack/core/storage/kvstore/mongodb/mongodb.py
Normal file
85
src/llama_stack/core/storage/kvstore/mongodb/mongodb.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
|
||||
from llama_stack.core.storage.kvstore import KVStore
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..config import MongoDBKVStoreConfig
|
||||
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class MongoDBKVStoreImpl(KVStore):
|
||||
def __init__(self, config: MongoDBKVStoreConfig):
|
||||
self.config = config
|
||||
self.conn: AsyncMongoClient | None = None
|
||||
|
||||
@property
|
||||
def collection(self) -> AsyncCollection:
|
||||
if self.conn is None:
|
||||
raise RuntimeError("MongoDB connection is not initialized")
|
||||
return self.conn[self.config.db][self.config.collection_name]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
# Pass parameters explicitly to satisfy mypy - AsyncMongoClient doesn't accept **dict
|
||||
self.conn = AsyncMongoClient(
|
||||
host=self.config.host if self.config.host is not None else None,
|
||||
port=self.config.port if self.config.port is not None else None,
|
||||
username=self.config.user if self.config.user is not None else None,
|
||||
password=self.config.password if self.config.password is not None else None,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to MongoDB database server")
|
||||
raise RuntimeError("Could not connect to MongoDB database server") from e
|
||||
|
||||
def _namespaced_key(self, key: str) -> str:
|
||||
if not self.config.namespace:
|
||||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
update_query = {"$set": {"value": value, "expiration": expiration}}
|
||||
await self.collection.update_one({"key": key}, update_query, upsert=True)
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
key = self._namespaced_key(key)
|
||||
query = {"key": key}
|
||||
result = await self.collection.find_one(query, {"value": 1, "_id": 0})
|
||||
return result["value"] if result else None
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.collection.delete_one({"key": key})
|
||||
|
||||
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
query = {
|
||||
"key": {"$gte": start_key, "$lt": end_key},
|
||||
}
|
||||
cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
|
||||
result = []
|
||||
async for doc in cursor:
|
||||
result.append(doc["value"])
|
||||
return result
|
||||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
query = {"key": {"$gte": start_key, "$lt": end_key}}
|
||||
cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
|
||||
# AsyncCursor requires async iteration
|
||||
result = []
|
||||
async for doc in cursor:
|
||||
result.append(doc["key"])
|
||||
return result
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .postgres import PostgresKVStoreImpl # noqa: F401 F403
|
||||
125
src/llama_stack/core/storage/kvstore/postgres/postgres.py
Normal file
125
src/llama_stack/core/storage/kvstore/postgres/postgres.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import psycopg2 # type: ignore[import-not-found]
|
||||
from psycopg2.extensions import connection as PGConnection # type: ignore[import-not-found]
|
||||
from psycopg2.extras import DictCursor # type: ignore[import-not-found]
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api.internal.kvstore import KVStore
|
||||
|
||||
from ..config import PostgresKVStoreConfig
|
||||
|
||||
log = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class PostgresKVStoreImpl(KVStore):
|
||||
def __init__(self, config: PostgresKVStoreConfig):
|
||||
self.config = config
|
||||
self._conn: PGConnection | None = None
|
||||
self._cursor: DictCursor | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self._conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
database=self.config.db,
|
||||
user=self.config.user,
|
||||
password=self.config.password,
|
||||
sslmode=self.config.ssl_mode,
|
||||
sslrootcert=self.config.ca_cert_path,
|
||||
)
|
||||
self._conn.autocommit = True
|
||||
self._cursor = self._conn.cursor(cursor_factory=DictCursor)
|
||||
|
||||
# Create table if it doesn't exist
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.config.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
expiration TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to PostgreSQL database server")
|
||||
raise RuntimeError("Could not connect to PostgreSQL database server") from e
|
||||
|
||||
def _cursor_or_raise(self) -> DictCursor:
|
||||
if self._cursor is None:
|
||||
raise RuntimeError("Postgres client not initialized")
|
||||
return self._cursor
|
||||
|
||||
def _namespaced_key(self, key: str) -> str:
|
||||
if not self.config.namespace:
|
||||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT INTO {self.config.table_name} (key, value, expiration)
|
||||
VALUES (%s, %s, %s)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
|
||||
""",
|
||||
(key, value, expiration),
|
||||
)
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
key = self._namespaced_key(key)
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT value FROM {self.config.table_name}
|
||||
WHERE key = %s
|
||||
AND (expiration IS NULL OR expiration > NOW())
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"DELETE FROM {self.config.table_name} WHERE key = %s",
|
||||
(key,),
|
||||
)
|
||||
|
||||
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT value FROM {self.config.table_name}
|
||||
WHERE key >= %s AND key < %s
|
||||
AND (expiration IS NULL OR expiration > NOW())
|
||||
ORDER BY key
|
||||
""",
|
||||
(start_key, end_key),
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
|
||||
(start_key, end_key),
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
7
src/llama_stack/core/storage/kvstore/redis/__init__.py
Normal file
7
src/llama_stack/core/storage/kvstore/redis/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .redis import RedisKVStoreImpl # noqa: F401
|
||||
101
src/llama_stack/core/storage/kvstore/redis/redis.py
Normal file
101
src/llama_stack/core/storage/kvstore/redis/redis.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from redis.asyncio import Redis # type: ignore[import-not-found]
|
||||
|
||||
from llama_stack_api.internal.kvstore import KVStore
|
||||
|
||||
from ..config import RedisKVStoreConfig
|
||||
|
||||
|
||||
class RedisKVStoreImpl(KVStore):
|
||||
def __init__(self, config: RedisKVStoreConfig):
|
||||
self.config = config
|
||||
self._redis: Redis | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._redis = Redis.from_url(self.config.url)
|
||||
|
||||
def _client(self) -> Redis:
|
||||
if self._redis is None:
|
||||
raise RuntimeError("Redis client not initialized")
|
||||
return self._redis
|
||||
|
||||
def _namespaced_key(self, key: str) -> str:
|
||||
if not self.config.namespace:
|
||||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
client = self._client()
|
||||
await client.set(key, value)
|
||||
if expiration:
|
||||
await client.expireat(key, expiration)
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
key = self._namespaced_key(key)
|
||||
client = self._client()
|
||||
value = await client.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
await client.ttl(key)
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self._client().delete(key)
|
||||
|
||||
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
client = self._client()
|
||||
cursor = 0
|
||||
pattern = start_key + "*" # Match all keys starting with start_key prefix
|
||||
matching_keys: list[str | bytes] = []
|
||||
while True:
|
||||
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
|
||||
|
||||
for key in keys:
|
||||
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
|
||||
if start_key <= key_str <= end_key:
|
||||
matching_keys.append(key)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
# Then fetch all values in a single MGET call
|
||||
if matching_keys:
|
||||
values = await client.mget(matching_keys)
|
||||
return [
|
||||
value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
"""Get all keys in the given range."""
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
client = self._client()
|
||||
cursor = 0
|
||||
pattern = start_key + "*"
|
||||
result: list[str] = []
|
||||
while True:
|
||||
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
|
||||
for key in keys:
|
||||
key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
|
||||
if start_key <= key_str <= end_key:
|
||||
result.append(key_str)
|
||||
if cursor == 0:
|
||||
break
|
||||
return result
|
||||
7
src/llama_stack/core/storage/kvstore/sqlite/__init__.py
Normal file
7
src/llama_stack/core/storage/kvstore/sqlite/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .sqlite import SqliteKVStoreImpl # noqa: F401
|
||||
174
src/llama_stack/core/storage/kvstore/sqlite/sqlite.py
Normal file
174
src/llama_stack/core/storage/kvstore/sqlite/sqlite.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api.internal.kvstore import KVStore
|
||||
|
||||
from ..config import SqliteKVStoreConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
||||
class SqliteKVStoreImpl(KVStore):
|
||||
def __init__(self, config: SqliteKVStoreConfig):
|
||||
self.db_path = config.db_path
|
||||
self.table_name = "kvstore"
|
||||
self._conn: aiosqlite.Connection | None = None
|
||||
|
||||
def __str__(self):
|
||||
return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})"
|
||||
|
||||
def _is_memory_db(self) -> bool:
|
||||
"""Check if this is an in-memory database."""
|
||||
return self.db_path == ":memory:" or "mode=memory" in self.db_path
|
||||
|
||||
async def initialize(self):
|
||||
# Skip directory creation for in-memory databases and file: URIs
|
||||
if not self._is_memory_db() and not self.db_path.startswith("file:"):
|
||||
db_dir = os.path.dirname(self.db_path)
|
||||
if db_dir: # Only create if there's a directory component
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
# Only use persistent connection for in-memory databases
|
||||
# File-based databases use connection-per-operation to avoid hangs
|
||||
if self._is_memory_db():
|
||||
self._conn = await aiosqlite.connect(self.db_path)
|
||||
await self._conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
expiration TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
await self._conn.commit()
|
||||
else:
|
||||
# For file-based databases, just create the table
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
expiration TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def shutdown(self):
|
||||
"""Close the persistent connection (only for in-memory databases)."""
|
||||
if self._conn:
|
||||
await self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
await self._conn.execute(
|
||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||
(key, value, expiration),
|
||||
)
|
||||
await self._conn.commit()
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||
(key, value, expiration),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
async with self._conn.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
if not isinstance(value, str):
|
||||
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
|
||||
return None
|
||||
return value
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
if not isinstance(value, str):
|
||||
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
|
||||
return None
|
||||
return value
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await self._conn.commit()
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await db.commit()
|
||||
|
||||
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
async with self._conn.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
_, value, _ = row
|
||||
result.append(value)
|
||||
return result
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
_, value, _ = row
|
||||
result.append(value)
|
||||
return result
|
||||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
"""Get all keys in the given range."""
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
cursor = await self._conn.execute(
|
||||
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
cursor = await db.execute(
|
||||
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
17
src/llama_stack/core/storage/sqlstore/__init__.py
Normal file
17
src/llama_stack/core/storage/sqlstore/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack_api.internal.sqlstore import (
|
||||
ColumnDefinition as ColumnDefinition,
|
||||
)
|
||||
from llama_stack_api.internal.sqlstore import (
|
||||
ColumnType as ColumnType,
|
||||
)
|
||||
from llama_stack_api.internal.sqlstore import (
|
||||
SqlStore as SqlStore,
|
||||
)
|
||||
|
||||
from .sqlstore import * # noqa: F401,F403
|
||||
336
src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py
Normal file
336
src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
||||
from llama_stack.core.access_control.conditions import ProtectedResource
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.core.storage.datatypes import StorageBackendType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import PaginatedResponse
|
||||
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
# Hardcoded copy of the default policy that our SQL filtering implements
|
||||
# WARNING: If default_policy() changes, this constant must be updated accordingly
|
||||
# or SQL filtering will fall back to conservative mode (safe but less performant)
|
||||
#
|
||||
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
|
||||
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
|
||||
# - Public records (no access_attributes) are always accessible
|
||||
# - Records with access_attributes require user to match ALL categories that exist in the resource
|
||||
# - Missing categories in the resource are treated as "no restriction" (allow)
|
||||
# - Within each category, user needs ANY matching value (OR logic)
|
||||
# - Between categories, user needs ALL categories to match (AND logic)
|
||||
SQL_OPTIMIZED_POLICY = [
|
||||
AccessRule(
|
||||
permit=Scope(actions=list(Action)),
|
||||
when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
|
||||
"""Add access control attributes to a data item."""
|
||||
enhanced = dict(item)
|
||||
if current_user:
|
||||
enhanced["owner_principal"] = current_user.principal
|
||||
enhanced["access_attributes"] = current_user.attributes
|
||||
else:
|
||||
# IMPORTANT: Use empty string and null value (not None) to match public access filter
|
||||
# The public access filter in _get_public_access_conditions() expects:
|
||||
# - owner_principal = '' (empty string)
|
||||
# - access_attributes = null (JSON null, which serializes to the string 'null')
|
||||
# Setting them to None (SQL NULL) will cause rows to be filtered out on read.
|
||||
enhanced["owner_principal"] = ""
|
||||
enhanced["access_attributes"] = None # Pydantic/JSON will serialize this as JSON null
|
||||
return enhanced
|
||||
|
||||
|
||||
class SqlRecord(ProtectedResource):
|
||||
def __init__(self, record_id: str, table_name: str, owner: User):
|
||||
self.type = f"sql_record::{table_name}"
|
||||
self.identifier = record_id
|
||||
self.owner = owner
|
||||
|
||||
|
||||
class AuthorizedSqlStore:
|
||||
"""
|
||||
Authorization layer for SqlStore that provides access control functionality.
|
||||
|
||||
This class composes a base SqlStore and adds authorization methods that handle
|
||||
access control policies, user attribute capture, and SQL filtering optimization.
|
||||
"""
|
||||
|
||||
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
|
||||
"""
|
||||
Initialize the authorization layer.
|
||||
|
||||
:param sql_store: Base SqlStore implementation to wrap
|
||||
:param policy: Access control policy to use for authorization
|
||||
"""
|
||||
self.sql_store = sql_store
|
||||
self.policy = policy
|
||||
self._detect_database_type()
|
||||
self._validate_sql_optimized_policy()
|
||||
|
||||
def _detect_database_type(self) -> None:
|
||||
"""Detect the database type from the underlying SQL store."""
|
||||
if not hasattr(self.sql_store, "config"):
|
||||
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
|
||||
|
||||
self.database_type = self.sql_store.config.type.value
|
||||
if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
||||
def _validate_sql_optimized_policy(self) -> None:
|
||||
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
||||
|
||||
This ensures that if default_policy() changes, we detect the mismatch and
|
||||
can update our SQL filtering logic accordingly.
|
||||
"""
|
||||
actual_default = default_policy()
|
||||
|
||||
if SQL_OPTIMIZED_POLICY != actual_default:
|
||||
logger.warning(
|
||||
f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
|
||||
f"SQL filtering will use conservative mode. "
|
||||
f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
|
||||
)
|
||||
|
||||
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
|
||||
"""Create a table with built-in access control support."""
|
||||
|
||||
enhanced_schema = dict(schema)
|
||||
if "access_attributes" not in enhanced_schema:
|
||||
enhanced_schema["access_attributes"] = ColumnType.JSON
|
||||
if "owner_principal" not in enhanced_schema:
|
||||
enhanced_schema["owner_principal"] = ColumnType.STRING
|
||||
|
||||
await self.sql_store.create_table(table, enhanced_schema)
|
||||
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
||||
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||
"""Insert a row or batch of rows with automatic access control attribute capture."""
|
||||
current_user = get_authenticated_user()
|
||||
enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
|
||||
if isinstance(data, Mapping):
|
||||
enhanced_data = _enhance_item_with_access_control(data, current_user)
|
||||
else:
|
||||
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
|
||||
await self.sql_store.insert(table, enhanced_data)
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
conflict_columns: list[str],
|
||||
update_columns: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Upsert a row with automatic access control attribute capture."""
|
||||
current_user = get_authenticated_user()
|
||||
enhanced_data = _enhance_item_with_access_control(data, current_user)
|
||||
await self.sql_store.upsert(
|
||||
table=table,
|
||||
data=enhanced_data,
|
||||
conflict_columns=conflict_columns,
|
||||
update_columns=update_columns,
|
||||
)
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
cursor: tuple[str, str] | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""Fetch all rows with automatic access control filtering."""
|
||||
access_where = self._build_access_control_where_clause(self.policy)
|
||||
rows = await self.sql_store.fetch_all(
|
||||
table=table,
|
||||
where=where,
|
||||
where_sql=access_where,
|
||||
limit=limit,
|
||||
order_by=order_by,
|
||||
cursor=cursor,
|
||||
)
|
||||
|
||||
current_user = get_authenticated_user()
|
||||
filtered_rows = []
|
||||
|
||||
for row in rows.data:
|
||||
stored_access_attrs = row.get("access_attributes")
|
||||
stored_owner_principal = row.get("owner_principal") or ""
|
||||
|
||||
record_id = row.get("id", "unknown")
|
||||
sql_record = SqlRecord(
|
||||
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
||||
)
|
||||
|
||||
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
|
||||
filtered_rows.append(row)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=filtered_rows,
|
||||
has_more=rows.has_more,
|
||||
)
|
||||
|
||||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch one row with automatic access control checking."""
|
||||
results = await self.fetch_all(
|
||||
table=table,
|
||||
where=where,
|
||||
limit=1,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
||||
return results.data[0] if results.data else None
|
||||
|
||||
async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
|
||||
"""Update rows with automatic access control attribute capture."""
|
||||
enhanced_data = dict(data)
|
||||
|
||||
current_user = get_authenticated_user()
|
||||
if current_user:
|
||||
enhanced_data["owner_principal"] = current_user.principal
|
||||
enhanced_data["access_attributes"] = current_user.attributes
|
||||
else:
|
||||
# IMPORTANT: Use empty string for owner_principal to match public access filter
|
||||
enhanced_data["owner_principal"] = ""
|
||||
enhanced_data["access_attributes"] = None # Will serialize as JSON null
|
||||
|
||||
await self.sql_store.update(table, enhanced_data, where)
|
||||
|
||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||
"""Delete rows with automatic access control filtering."""
|
||||
await self.sql_store.delete(table, where)
|
||||
|
||||
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
|
||||
"""Build SQL WHERE clause for access control filtering.
|
||||
|
||||
Only applies SQL filtering for the default policy to ensure correctness.
|
||||
For custom policies, uses conservative filtering to avoid blocking legitimate access.
|
||||
"""
|
||||
current_user = get_authenticated_user()
|
||||
|
||||
if not policy or policy == SQL_OPTIMIZED_POLICY:
|
||||
return self._build_default_policy_where_clause(current_user)
|
||||
else:
|
||||
return self._build_conservative_where_clause()
|
||||
|
||||
def _json_extract(self, column: str, path: str) -> str:
|
||||
"""Extract JSON value (keeping JSON type).
|
||||
|
||||
Args:
|
||||
column: The JSON column name
|
||||
path: The JSON path (e.g., 'roles', 'teams')
|
||||
|
||||
Returns:
|
||||
SQL expression to extract JSON value
|
||||
"""
|
||||
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
|
||||
return f"{column}->'{path}'"
|
||||
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
||||
def _json_extract_text(self, column: str, path: str) -> str:
|
||||
"""Extract JSON value as text.
|
||||
|
||||
Args:
|
||||
column: The JSON column name
|
||||
path: The JSON path (e.g., 'roles', 'teams')
|
||||
|
||||
Returns:
|
||||
SQL expression to extract JSON value as text
|
||||
"""
|
||||
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
|
||||
return f"{column}->>'{path}'"
|
||||
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
|
||||
def _get_public_access_conditions(self) -> list[str]:
|
||||
"""Get the SQL conditions for public access.
|
||||
|
||||
Public records are those with:
|
||||
- owner_principal = '' (empty string)
|
||||
- access_attributes is either SQL NULL or JSON null
|
||||
|
||||
Note: Different databases serialize None differently:
|
||||
- SQLite: None → JSON null (text = 'null')
|
||||
- Postgres: None → SQL NULL (IS NULL)
|
||||
"""
|
||||
conditions = ["owner_principal = ''"]
|
||||
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
|
||||
# Accept both SQL NULL and JSON null for Postgres compatibility
|
||||
# This handles both old rows (SQL NULL) and new rows (JSON null)
|
||||
conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
|
||||
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
|
||||
# SQLite serializes None as JSON null
|
||||
conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
return conditions
|
||||
|
||||
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
|
||||
"""Build SQL WHERE clause for the default policy.
|
||||
|
||||
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
|
||||
This means user must match ALL attribute categories that exist in the resource.
|
||||
"""
|
||||
base_conditions = self._get_public_access_conditions()
|
||||
user_attr_conditions = []
|
||||
|
||||
if current_user and current_user.attributes:
|
||||
for attr_key, user_values in current_user.attributes.items():
|
||||
if user_values:
|
||||
value_conditions = []
|
||||
for value in user_values:
|
||||
# Check if JSON array contains the value
|
||||
escaped_value = value.replace("'", "''")
|
||||
json_text = self._json_extract_text("access_attributes", attr_key)
|
||||
value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
|
||||
|
||||
if value_conditions:
|
||||
# Check if the category is missing (NULL)
|
||||
category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
|
||||
user_matches_category = f"({' OR '.join(value_conditions)})"
|
||||
user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
|
||||
|
||||
if user_attr_conditions:
|
||||
all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
|
||||
base_conditions.append(all_requirements_met)
|
||||
|
||||
return f"({' OR '.join(base_conditions)})"
|
||||
|
||||
def _build_conservative_where_clause(self) -> str:
|
||||
"""Conservative SQL filtering for custom policies.
|
||||
|
||||
Only filters records we're 100% certain would be denied by any reasonable policy.
|
||||
"""
|
||||
current_user = get_authenticated_user()
|
||||
|
||||
if not current_user:
|
||||
# Only allow public records
|
||||
base_conditions = self._get_public_access_conditions()
|
||||
return f"({' OR '.join(base_conditions)})"
|
||||
|
||||
return "1=1"
|
||||
373
src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py
Normal file
373
src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
Integer,
|
||||
MetaData,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
event,
|
||||
inspect,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import PaginatedResponse
|
||||
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||
ColumnType.INTEGER: Integer,
|
||||
ColumnType.STRING: String,
|
||||
ColumnType.FLOAT: Float,
|
||||
ColumnType.BOOLEAN: Boolean,
|
||||
ColumnType.DATETIME: DateTime,
|
||||
ColumnType.TEXT: Text,
|
||||
ColumnType.JSON: JSON,
|
||||
}
|
||||
|
||||
|
||||
def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
|
||||
"""Return a SQLAlchemy expression for a where condition.
|
||||
|
||||
`value` may be a simple scalar (equality) or a mapping like {">": 123}.
|
||||
The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
|
||||
"""
|
||||
if isinstance(value, Mapping):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
|
||||
op, operand = next(iter(value.items()))
|
||||
if op == "==" or op == "=":
|
||||
return cast(ColumnElement[Any], column == operand)
|
||||
if op == ">":
|
||||
return cast(ColumnElement[Any], column > operand)
|
||||
if op == "<":
|
||||
return cast(ColumnElement[Any], column < operand)
|
||||
if op == ">=":
|
||||
return cast(ColumnElement[Any], column >= operand)
|
||||
if op == "<=":
|
||||
return cast(ColumnElement[Any], column <= operand)
|
||||
raise ValueError(f"Unsupported operator '{op}' in where mapping")
|
||||
return cast(ColumnElement[Any], column == value)
|
||||
|
||||
|
||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||
self.config = config
|
||||
self._is_sqlite_backend = "sqlite" in self.config.engine_str
|
||||
self.async_session = async_sessionmaker(self.create_engine())
|
||||
self.metadata = MetaData()
|
||||
|
||||
def create_engine(self) -> AsyncEngine:
|
||||
# Configure connection args for better concurrency support
|
||||
connect_args = {}
|
||||
if self._is_sqlite_backend:
|
||||
# SQLite-specific optimizations for concurrent access
|
||||
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
|
||||
connect_args["timeout"] = 5.0
|
||||
connect_args["check_same_thread"] = False # Allow usage across asyncio tasks
|
||||
|
||||
engine = create_async_engine(
|
||||
self.config.engine_str,
|
||||
pool_pre_ping=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
|
||||
# Enable WAL mode for SQLite to support concurrent readers and writers
|
||||
if self._is_sqlite_backend:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_conn, connection_record):
|
||||
cursor = dbapi_conn.cursor()
|
||||
# Enable Write-Ahead Logging for better concurrency
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
# Set busy timeout to 5 seconds (retry instead of immediate failure)
|
||||
# With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
|
||||
cursor.execute("PRAGMA busy_timeout=5000")
|
||||
# Use NORMAL synchronous mode for better performance (still safe with WAL)
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
async def create_table(
|
||||
self,
|
||||
table: str,
|
||||
schema: Mapping[str, ColumnType | ColumnDefinition],
|
||||
) -> None:
|
||||
if not schema:
|
||||
raise ValueError(f"No columns defined for table '{table}'.")
|
||||
|
||||
sqlalchemy_columns: list[Column] = []
|
||||
|
||||
for col_name, col_props in schema.items():
|
||||
col_type = None
|
||||
is_primary_key = False
|
||||
is_nullable = True
|
||||
|
||||
if isinstance(col_props, ColumnType):
|
||||
col_type = col_props
|
||||
elif isinstance(col_props, ColumnDefinition):
|
||||
col_type = col_props.type
|
||||
is_primary_key = col_props.primary_key
|
||||
is_nullable = col_props.nullable
|
||||
|
||||
sqlalchemy_type = TYPE_MAPPING.get(col_type)
|
||||
if not sqlalchemy_type:
|
||||
raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.")
|
||||
|
||||
sqlalchemy_columns.append(
|
||||
Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
|
||||
)
|
||||
|
||||
if table not in self.metadata.tables:
|
||||
sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
|
||||
else:
|
||||
sqlalchemy_table = self.metadata.tables[table]
|
||||
|
||||
engine = self.create_engine()
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||
async with self.async_session() as session:
|
||||
await session.execute(self.metadata.tables[table].insert(), data)
|
||||
await session.commit()
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
conflict_columns: list[str],
|
||||
update_columns: list[str] | None = None,
|
||||
) -> None:
|
||||
table_obj = self.metadata.tables[table]
|
||||
dialect_insert = self._get_dialect_insert(table_obj)
|
||||
insert_stmt = dialect_insert.values(**data)
|
||||
|
||||
if update_columns is None:
|
||||
update_columns = [col for col in data.keys() if col not in conflict_columns]
|
||||
|
||||
update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
|
||||
conflict_cols = [table_obj.c[col] for col in conflict_columns]
|
||||
|
||||
stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
|
||||
|
||||
async with self.async_session() as session:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
where_sql: str | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
cursor: tuple[str, str] | None = None,
|
||||
) -> PaginatedResponse:
|
||||
async with self.async_session() as session:
|
||||
table_obj = self.metadata.tables[table]
|
||||
query = select(table_obj)
|
||||
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
query = query.where(_build_where_expr(table_obj.c[key], value))
|
||||
|
||||
if where_sql:
|
||||
query = query.where(text(where_sql))
|
||||
|
||||
# Handle cursor-based pagination
|
||||
if cursor:
|
||||
# Validate cursor tuple format
|
||||
if not isinstance(cursor, tuple) or len(cursor) != 2:
|
||||
raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")
|
||||
|
||||
# Require order_by for cursor pagination
|
||||
if not order_by:
|
||||
raise ValueError("order_by is required when using cursor pagination")
|
||||
|
||||
# Only support single-column ordering for cursor pagination
|
||||
if len(order_by) != 1:
|
||||
raise ValueError(
|
||||
f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
|
||||
)
|
||||
|
||||
cursor_key_column, cursor_id = cursor
|
||||
order_column, order_direction = order_by[0]
|
||||
|
||||
# Verify cursor_key_column exists
|
||||
if cursor_key_column not in table_obj.c:
|
||||
raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")
|
||||
|
||||
# Get cursor value for the order column
|
||||
cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
|
||||
cursor_result = await session.execute(cursor_query)
|
||||
cursor_row = cursor_result.fetchone()
|
||||
|
||||
if not cursor_row:
|
||||
raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")
|
||||
|
||||
cursor_value = cursor_row[0]
|
||||
|
||||
# Apply cursor condition based on sort direction
|
||||
if order_direction == "desc":
|
||||
query = query.where(table_obj.c[order_column] < cursor_value)
|
||||
else:
|
||||
query = query.where(table_obj.c[order_column] > cursor_value)
|
||||
|
||||
# Apply ordering
|
||||
if order_by:
|
||||
if not isinstance(order_by, list):
|
||||
raise ValueError(
|
||||
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||
)
|
||||
for order in order_by:
|
||||
if not isinstance(order, tuple):
|
||||
raise ValueError(
|
||||
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||
)
|
||||
name, order_type = order
|
||||
if name not in table_obj.c:
|
||||
raise ValueError(f"Column '{name}' not found in table '{table}'")
|
||||
if order_type == "asc":
|
||||
query = query.order_by(table_obj.c[name].asc())
|
||||
elif order_type == "desc":
|
||||
query = query.order_by(table_obj.c[name].desc())
|
||||
else:
|
||||
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
||||
|
||||
# Fetch limit + 1 to determine has_more
|
||||
fetch_limit = limit
|
||||
if limit:
|
||||
fetch_limit = limit + 1
|
||||
|
||||
if fetch_limit:
|
||||
query = query.limit(fetch_limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
# Iterate directly - if no rows, list comprehension yields empty list
|
||||
rows = [dict(row._mapping) for row in result]
|
||||
|
||||
# Always return pagination result
|
||||
has_more = False
|
||||
if limit and len(rows) > limit:
|
||||
has_more = True
|
||||
rows = rows[:limit]
|
||||
|
||||
return PaginatedResponse(data=rows, has_more=has_more)
|
||||
|
||||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
where: Mapping[str, Any] | None = None,
|
||||
where_sql: str | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
|
||||
if not result.data:
|
||||
return None
|
||||
return result.data[0]
|
||||
|
||||
async def update(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
where: Mapping[str, Any],
|
||||
) -> None:
|
||||
if not where:
|
||||
raise ValueError("where is required for update")
|
||||
|
||||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].update()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||
await session.execute(stmt, data)
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||
if not where:
|
||||
raise ValueError("where is required for delete")
|
||||
|
||||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].delete()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def add_column_if_not_exists(
|
||||
self,
|
||||
table: str,
|
||||
column_name: str,
|
||||
column_type: ColumnType,
|
||||
nullable: bool = True,
|
||||
) -> None:
|
||||
"""Add a column to an existing table if the column doesn't already exist."""
|
||||
engine = self.create_engine()
|
||||
|
||||
try:
|
||||
async with engine.begin() as conn:
|
||||
|
||||
def check_column_exists(sync_conn):
|
||||
inspector = inspect(sync_conn)
|
||||
|
||||
table_names = inspector.get_table_names()
|
||||
if table not in table_names:
|
||||
return False, False # table doesn't exist, column doesn't exist
|
||||
|
||||
existing_columns = inspector.get_columns(table)
|
||||
column_names = [col["name"] for col in existing_columns]
|
||||
|
||||
return True, column_name in column_names # table exists, column exists or not
|
||||
|
||||
table_exists, column_exists = await conn.run_sync(check_column_exists)
|
||||
if not table_exists or column_exists:
|
||||
return
|
||||
|
||||
sqlalchemy_type = TYPE_MAPPING.get(column_type)
|
||||
if not sqlalchemy_type:
|
||||
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
|
||||
|
||||
# Create the ALTER TABLE statement
|
||||
# Note: We need to get the dialect-specific type name
|
||||
dialect = engine.dialect
|
||||
type_impl = sqlalchemy_type()
|
||||
compiled_type = type_impl.compile(dialect=dialect)
|
||||
|
||||
nullable_clause = "" if nullable else " NOT NULL"
|
||||
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
|
||||
|
||||
await conn.execute(add_column_sql)
|
||||
except Exception as e:
|
||||
# If any error occurs during migration, log it but don't fail
|
||||
# The table creation will handle adding the column
|
||||
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||
pass
|
||||
|
||||
def _get_dialect_insert(self, table: Table):
|
||||
if self._is_sqlite_backend:
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
return sqlite_insert(table)
|
||||
else:
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
return pg_insert(table)
|
||||
87
src/llama_stack/core/storage/sqlstore/sqlstore.py
Normal file
87
src/llama_stack/core/storage/sqlstore/sqlstore.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from threading import Lock
|
||||
from typing import Annotated, cast
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
PostgresSqlStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageBackendConfig,
|
||||
StorageBackendType,
|
||||
)
|
||||
from llama_stack_api.internal.sqlstore import SqlStore
|
||||
|
||||
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
||||
|
||||
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
|
||||
_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
|
||||
_SQLSTORE_LOCKS: dict[str, Lock] = {}
|
||||
|
||||
|
||||
SqlStoreConfig = Annotated[
|
||||
SqliteSqlStoreConfig | PostgresSqlStoreConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
|
||||
"""Get pip packages for SQL store config, handling both dict and object cases."""
|
||||
if isinstance(store_config, dict):
|
||||
store_type = store_config.get("type")
|
||||
if store_type == StorageBackendType.SQL_SQLITE.value:
|
||||
return SqliteSqlStoreConfig.pip_packages()
|
||||
elif store_type == StorageBackendType.SQL_POSTGRES.value:
|
||||
return PostgresSqlStoreConfig.pip_packages()
|
||||
else:
|
||||
raise ValueError(f"Unknown SQL store type: {store_type}")
|
||||
else:
|
||||
return store_config.pip_packages()
|
||||
|
||||
|
||||
def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
|
||||
backend_name = reference.backend
|
||||
|
||||
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
|
||||
if backend_config is None:
|
||||
raise ValueError(
|
||||
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
|
||||
existing = _SQLSTORE_INSTANCES.get(backend_name)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
|
||||
with lock:
|
||||
existing = _SQLSTORE_INSTANCES.get(backend_name)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
|
||||
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
|
||||
instance = SqlAlchemySqlStoreImpl(config)
|
||||
_SQLSTORE_INSTANCES[backend_name] = instance
|
||||
return instance
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
|
||||
|
||||
|
||||
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
|
||||
"""Register the set of available SQL store backends for reference resolution."""
|
||||
global _SQLSTORE_BACKENDS
|
||||
global _SQLSTORE_INSTANCES
|
||||
|
||||
_SQLSTORE_BACKENDS.clear()
|
||||
_SQLSTORE_INSTANCES.clear()
|
||||
_SQLSTORE_LOCKS.clear()
|
||||
for name, cfg in backends.items():
|
||||
_SQLSTORE_BACKENDS[name] = cfg
|
||||
|
|
@ -12,8 +12,8 @@ import pydantic
|
|||
|
||||
from llama_stack.core.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
|
||||
logger = get_logger(__name__, category="core::registry")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue