diff --git a/src/llama_stack/core/conversations/conversations.py b/src/llama_stack/core/conversations/conversations.py
index 4cf5a82ee..21598f9fa 100644
--- a/src/llama_stack/core/conversations/conversations.py
+++ b/src/llama_stack/core/conversations/conversations.py
@@ -12,9 +12,8 @@ from pydantic import BaseModel, TypeAdapter
 
 from llama_stack.core.datatypes import AccessRule, StackRunConfig
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     Conversation,
     ConversationDeletedResource,
@@ -25,6 +24,7 @@ from llama_stack_api import (
     Conversations,
     Metadata,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
 
 logger = get_logger(name=__name__, category="openai_conversations")
diff --git a/src/llama_stack/core/prompts/prompts.py b/src/llama_stack/core/prompts/prompts.py
index 9f532c1cd..ff67ad138 100644
--- a/src/llama_stack/core/prompts/prompts.py
+++ b/src/llama_stack/core/prompts/prompts.py
@@ -10,7 +10,7 @@ from typing import Any
 from pydantic import BaseModel
 
 from llama_stack.core.datatypes import StackRunConfig
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
 from llama_stack_api import ListPromptsResponse, Prompt, Prompts
diff --git a/src/llama_stack/core/server/quota.py b/src/llama_stack/core/server/quota.py
index 689f0e4c3..4688462b4 100644
--- a/src/llama_stack/core/server/quota.py
+++ b/src/llama_stack/core/server/quota.py
@@ -12,8 +12,8 @@ from starlette.types import ASGIApp, Receive, Scope, Send
 
 from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore.api import KVStore
-from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
+from llama_stack.core.storage.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
+from llama_stack_api.internal.kvstore import KVStore
 
 logger = get_logger(name=__name__, category="core::server")
diff --git a/src/llama_stack/core/stack.py b/src/llama_stack/core/stack.py
index 00d990cb1..8ba1f2afd 100644
--- a/src/llama_stack/core/stack.py
+++ b/src/llama_stack/core/stack.py
@@ -385,8 +385,8 @@ def _initialize_storage(run_config: StackRunConfig):
     else:
         raise ValueError(f"Unknown storage backend type: {type}")
 
-    from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
-    from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+    from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends
+    from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
 
     register_kvstore_backends(kv_backends)
     register_sqlstore_backends(sql_backends)
diff --git a/src/llama_stack/core/storage/kvstore/__init__.py b/src/llama_stack/core/storage/kvstore/__init__.py
new file mode 100644
index 000000000..470a75d2d
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .kvstore import * # noqa: F401, F403
diff --git a/src/llama_stack/core/storage/kvstore/api.py b/src/llama_stack/core/storage/kvstore/api.py
new file mode 100644
index 000000000..83d28b87f
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/api.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack_api.internal.kvstore import KVStore
+
+__all__ = ["KVStore"]
diff --git a/src/llama_stack/core/storage/kvstore/config.py b/src/llama_stack/core/storage/kvstore/config.py
new file mode 100644
index 000000000..c0582abc4
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/config.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Annotated
+
+from pydantic import Field
+
+from llama_stack.core.storage.datatypes import (
+    MongoDBKVStoreConfig,
+    PostgresKVStoreConfig,
+    RedisKVStoreConfig,
+    SqliteKVStoreConfig,
+    StorageBackendType,
+)
+
+KVStoreConfig = Annotated[
+    RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
+]
+
+
+def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
+    """Get pip packages for KV store config, handling both dict and object cases."""
+    if isinstance(store_config, dict):
+        store_type = store_config.get("type")
+        if store_type == StorageBackendType.KV_SQLITE.value:
+            return SqliteKVStoreConfig.pip_packages()
+        elif store_type == StorageBackendType.KV_POSTGRES.value:
+            return PostgresKVStoreConfig.pip_packages()
+        elif store_type == StorageBackendType.KV_REDIS.value:
+            return RedisKVStoreConfig.pip_packages()
+        elif store_type == StorageBackendType.KV_MONGODB.value:
+            return MongoDBKVStoreConfig.pip_packages()
+        else:
+            raise ValueError(f"Unknown KV store type: {store_type}")
+    else:
+        return store_config.pip_packages()
diff --git a/src/llama_stack/core/storage/kvstore/kvstore.py b/src/llama_stack/core/storage/kvstore/kvstore.py
new file mode 100644
index 000000000..5b8d77102
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/kvstore.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from __future__ import annotations
+
+import asyncio
+from collections import defaultdict
+
+from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
+
+from .api import KVStore
+from .config import KVStoreConfig
+
+
+def kvstore_dependencies():
+    """
+    Returns all possible kvstore dependencies for registry/provider specifications.
+
+    NOTE: For specific kvstore implementations, use config.pip_packages instead.
+    This function returns the union of all dependencies for cases where the specific
+    kvstore type is not known at declaration time (e.g., provider registries).
+    """
+    return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
+
+
+class InmemoryKVStoreImpl(KVStore):
+    def __init__(self):
+        self._store = {}
+
+    async def initialize(self) -> None:
+        pass
+
+    async def get(self, key: str) -> str | None:
+        return self._store.get(key)
+
+    async def set(self, key: str, value: str) -> None:
+        self._store[key] = value
+
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
+        return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Get all keys in the given range."""
+        return [key for key in self._store.keys() if key >= start_key and key < end_key]
+
+    async def delete(self, key: str) -> None:
+        del self._store[key]
+
+
+_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
+_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
+_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
+
+
+def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
+    """Register the set of available KV store backends for reference resolution."""
+    global _KVSTORE_BACKENDS
+    global _KVSTORE_INSTANCES
+    global _KVSTORE_LOCKS
+
+    _KVSTORE_BACKENDS.clear()
+    _KVSTORE_INSTANCES.clear()
+    _KVSTORE_LOCKS.clear()
+    for name, cfg in backends.items():
+        _KVSTORE_BACKENDS[name] = cfg
+
+
+async def kvstore_impl(reference: KVStoreReference) -> KVStore:
+    backend_name = reference.backend
+    cache_key = (backend_name, reference.namespace)
+
+    existing = _KVSTORE_INSTANCES.get(cache_key)
+    if existing:
+        return existing
+
+    backend_config = _KVSTORE_BACKENDS.get(backend_name)
+    if backend_config is None:
+        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
+
+    lock = _KVSTORE_LOCKS[cache_key]
+    async with lock:
+        existing = _KVSTORE_INSTANCES.get(cache_key)
+        if existing:
+            return existing
+
+        config = backend_config.model_copy()
+        config.namespace = reference.namespace
+
+        if config.type == StorageBackendType.KV_REDIS.value:
+            from .redis import RedisKVStoreImpl
+
+            impl = RedisKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_SQLITE.value:
+            from .sqlite import SqliteKVStoreImpl
+
+            impl = SqliteKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_POSTGRES.value:
+            from .postgres import PostgresKVStoreImpl
+
+            impl = PostgresKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_MONGODB.value:
+            from .mongodb import MongoDBKVStoreImpl
+
+            impl = MongoDBKVStoreImpl(config)
+        else:
+            raise ValueError(f"Unknown kvstore type {config.type}")
+
+        await impl.initialize()
+        _KVSTORE_INSTANCES[cache_key] = impl
+        return impl
diff --git a/src/llama_stack/core/storage/kvstore/mongodb/__init__.py b/src/llama_stack/core/storage/kvstore/mongodb/__init__.py
new file mode 100644
index 000000000..a42fee76f
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/mongodb/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .mongodb import MongoDBKVStoreImpl
+
+__all__ = ["MongoDBKVStoreImpl"]
diff --git a/src/llama_stack/core/storage/kvstore/mongodb/mongodb.py b/src/llama_stack/core/storage/kvstore/mongodb/mongodb.py
new file mode 100644
index 000000000..51df6d613
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/mongodb/mongodb.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+
+from pymongo import AsyncMongoClient
+from pymongo.asynchronous.collection import AsyncCollection
+
+from llama_stack.core.storage.kvstore import KVStore
+from llama_stack.log import get_logger
+
+from ..config import MongoDBKVStoreConfig
+
+log = get_logger(name=__name__, category="providers::utils")
+
+
+class MongoDBKVStoreImpl(KVStore):
+    def __init__(self, config: MongoDBKVStoreConfig):
+        self.config = config
+        self.conn: AsyncMongoClient | None = None
+
+    @property
+    def collection(self) -> AsyncCollection:
+        if self.conn is None:
+            raise RuntimeError("MongoDB connection is not initialized")
+        return self.conn[self.config.db][self.config.collection_name]
+
+    async def initialize(self) -> None:
+        try:
+            # Pass parameters explicitly to satisfy mypy - AsyncMongoClient doesn't accept **dict
+            self.conn = AsyncMongoClient(
+                host=self.config.host,
+                port=self.config.port,
+                username=self.config.user,
+                password=self.config.password,
+            )
+        except Exception as e:
+            log.exception("Could not connect to MongoDB database server")
+            raise RuntimeError("Could not connect to MongoDB database server") from e
+
+    def _namespaced_key(self, key: str) -> str:
+        if not self.config.namespace:
+            return key
+        return f"{self.config.namespace}:{key}"
+
+    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
+        key = self._namespaced_key(key)
+        update_query = {"$set": {"value": value, "expiration": expiration}}
+        await self.collection.update_one({"key": key}, update_query, upsert=True)
+
+    async def get(self, key: str) -> str | None:
+        key = self._namespaced_key(key)
+        query = {"key": key}
+        result = await self.collection.find_one(query, {"value": 1, "_id": 0})
+        return result["value"] if result else None
+
+    async def delete(self, key: str) -> None:
+        key = self._namespaced_key(key)
+        await self.collection.delete_one({"key": key})
+
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+        query = {
+            "key": {"$gte": start_key, "$lt": end_key},
+        }
+        cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
+        result = []
+        async for doc in cursor:
+            result.append(doc["value"])
+        return result
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+        query = {"key": {"$gte": start_key, "$lt": end_key}}
+        cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
+        # AsyncCursor requires async iteration
+        result = []
+        async for doc in cursor:
+            result.append(doc["key"])
+        return result
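
Aside: all of the networked backends share the `namespace:key` prefixing convention, and the range APIs lean on lexicographic ordering of the prefixed keys. A pure-Python illustration of that contract (no external services assumed):

```python
def namespaced(namespace: str | None, key: str) -> str:
    # Same shape as MongoDBKVStoreImpl._namespaced_key above.
    return f"{namespace}:{key}" if namespace else key


keys = sorted(namespaced("agents", k) for k in ("session:001", "session:002", "task:007"))
lo, hi = namespaced("agents", "session:"), namespaced("agents", "session:\xff")
# values_in_range/keys_in_range are this half-open-range predicate pushed into the backend.
assert [k for k in keys if lo <= k < hi] == ["agents:session:001", "agents:session:002"]
```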
diff --git a/src/llama_stack/core/storage/kvstore/postgres/__init__.py b/src/llama_stack/core/storage/kvstore/postgres/__init__.py
new file mode 100644
index 000000000..efbf6299d
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/postgres/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .postgres import PostgresKVStoreImpl # noqa: F401
diff --git a/src/llama_stack/core/storage/kvstore/postgres/postgres.py b/src/llama_stack/core/storage/kvstore/postgres/postgres.py
new file mode 100644
index 000000000..56d6dbb48
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/postgres/postgres.py
@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+
+import psycopg2
+from psycopg2.extras import DictCursor
+
+from llama_stack.log import get_logger
+
+from ..api import KVStore
+from ..config import PostgresKVStoreConfig
+
+log = get_logger(name=__name__, category="providers::utils")
+
+
+class PostgresKVStoreImpl(KVStore):
+    def __init__(self, config: PostgresKVStoreConfig):
+        self.config = config
+        self.conn = None
+        self.cursor = None
+
+    async def initialize(self) -> None:
+        try:
+            self.conn = psycopg2.connect(
+                host=self.config.host,
+                port=self.config.port,
+                database=self.config.db,
+                user=self.config.user,
+                password=self.config.password,
+                sslmode=self.config.ssl_mode,
+                sslrootcert=self.config.ca_cert_path,
+            )
+            self.conn.autocommit = True
+            self.cursor = self.conn.cursor(cursor_factory=DictCursor)
+
+            # Create table if it doesn't exist
+            self.cursor.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS {self.config.table_name} (
+                    key TEXT PRIMARY KEY,
+                    value TEXT,
+                    expiration TIMESTAMP
+                )
+                """
+            )
+        except Exception as e:
+            log.exception("Could not connect to PostgreSQL database server")
+            raise RuntimeError("Could not connect to PostgreSQL database server") from e
+
+    def _namespaced_key(self, key: str) -> str:
+        if not self.config.namespace:
+            return key
+        return f"{self.config.namespace}:{key}"
+
+    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
+        key = self._namespaced_key(key)
+        self.cursor.execute(
+            f"""
+            INSERT INTO {self.config.table_name} (key, value, expiration)
+            VALUES (%s, %s, %s)
+            ON CONFLICT (key) DO UPDATE
+            SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
+            """,
+            (key, value, expiration),
+        )
+
+    async def get(self, key: str) -> str | None:
+        key = self._namespaced_key(key)
+        self.cursor.execute(
+            f"""
+            SELECT value FROM {self.config.table_name}
+            WHERE key = %s
+            AND (expiration IS NULL OR expiration > NOW())
+            """,
+            (key,),
+        )
+        result = self.cursor.fetchone()
+        return result[0] if result else None
+
+    async def delete(self, key: str) -> None:
+        key = self._namespaced_key(key)
+        self.cursor.execute(
+            f"DELETE FROM {self.config.table_name} WHERE key = %s",
+            (key,),
+        )
+
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+
+        self.cursor.execute(
+            f"""
+            SELECT value FROM {self.config.table_name}
+            WHERE key >= %s AND key < %s
+            AND (expiration IS NULL OR expiration > NOW())
+            ORDER BY key
+            """,
+            (start_key, end_key),
+        )
+        return [row[0] for row in self.cursor.fetchall()]
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+
+        self.cursor.execute(
+            f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
+            (start_key, end_key),
+        )
+        return [row[0] for row in self.cursor.fetchall()]
diff --git a/src/llama_stack/core/storage/kvstore/redis/__init__.py b/src/llama_stack/core/storage/kvstore/redis/__init__.py
new file mode 100644
index 000000000..94693ca43
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/redis/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .redis import RedisKVStoreImpl # noqa: F401
diff --git a/src/llama_stack/core/storage/kvstore/redis/redis.py b/src/llama_stack/core/storage/kvstore/redis/redis.py
new file mode 100644
index 000000000..3d2d956c3
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/redis/redis.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+
+from redis.asyncio import Redis
+
+from ..api import KVStore
+from ..config import RedisKVStoreConfig
+
+
+class RedisKVStoreImpl(KVStore):
+    def __init__(self, config: RedisKVStoreConfig):
+        self.config = config
+
+    async def initialize(self) -> None:
+        self.redis = Redis.from_url(self.config.url)
+
+    def _namespaced_key(self, key: str) -> str:
+        if not self.config.namespace:
+            return key
+        return f"{self.config.namespace}:{key}"
+
+    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
+        key = self._namespaced_key(key)
+        await self.redis.set(key, value)
+        if expiration:
+            await self.redis.expireat(key, expiration)
+
+    async def get(self, key: str) -> str | None:
+        key = self._namespaced_key(key)
+        value = await self.redis.get(key)
+        if value is None:
+            return None
+        return value.decode("utf-8") if isinstance(value, bytes) else value
+
+    async def delete(self, key: str) -> None:
+        key = self._namespaced_key(key)
+        await self.redis.delete(key)
+
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+        cursor = 0
+        pattern = start_key + "*"  # Match all keys starting with start_key prefix
+        matching_keys = []
+        while True:
+            cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
+
+            for key in keys:
+                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+                if start_key <= key_str <= end_key:
+                    matching_keys.append(key)
+
+            if cursor == 0:
+                break
+
+        # Then fetch all values in a single MGET call
+        if matching_keys:
+            values = await self.redis.mget(matching_keys)
+            return [
+                value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
+            ]
+
+        return []
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Get all keys in the given range."""
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+        # Mirror values_in_range: SCAN by prefix, then filter to the range.
+        cursor = 0
+        pattern = start_key + "*"
+        matching_keys = []
+        while True:
+            cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
+            for key in keys:
+                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+                if start_key <= key_str <= end_key:
+                    matching_keys.append(key_str)
+            if cursor == 0:
+                break
+        return sorted(matching_keys)
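
Aside: a minimal usage sketch for the Redis backend. The `host`/`port`/`namespace` field names on `RedisKVStoreConfig` are assumptions here (`initialize()` only reads a derived `config.url`), and a reachable Redis server is required:

```python
import asyncio

from llama_stack.core.storage.datatypes import RedisKVStoreConfig
from llama_stack.core.storage.kvstore.redis import RedisKVStoreImpl


async def main() -> None:
    # Field names are assumptions; initialize() connects via the config's URL.
    store = RedisKVStoreImpl(RedisKVStoreConfig(host="localhost", port=6379, namespace="demo"))
    await store.initialize()
    await store.set("session:001", "{}")
    print(await store.values_in_range("session:", "session:\xff"))


asyncio.run(main())
```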
diff --git a/src/llama_stack/core/storage/kvstore/sqlite/__init__.py b/src/llama_stack/core/storage/kvstore/sqlite/__init__.py
new file mode 100644
index 000000000..03bc53c24
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/sqlite/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .sqlite import SqliteKVStoreImpl # noqa: F401
diff --git a/src/llama_stack/core/storage/kvstore/sqlite/config.py b/src/llama_stack/core/storage/kvstore/sqlite/config.py
new file mode 100644
index 000000000..0f8fa0a95
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/sqlite/config.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel, Field
+
+from llama_stack_api import json_schema_type
+
+
+@json_schema_type
+class SqliteControlPlaneConfig(BaseModel):
+    db_path: str = Field(
+        description="File path for the sqlite database",
+    )
+    table_name: str = Field(
+        default="llamastack_control_plane",
+        description="Table into which all the keys will be placed",
+    )
diff --git a/src/llama_stack/core/storage/kvstore/sqlite/sqlite.py b/src/llama_stack/core/storage/kvstore/sqlite/sqlite.py
new file mode 100644
index 000000000..a9a7a1304
--- /dev/null
+++ b/src/llama_stack/core/storage/kvstore/sqlite/sqlite.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from datetime import datetime
+
+import aiosqlite
+
+from llama_stack.log import get_logger
+
+from ..api import KVStore
+from ..config import SqliteKVStoreConfig
+
+logger = get_logger(name=__name__, category="providers::utils")
+
+
+class SqliteKVStoreImpl(KVStore):
+    def __init__(self, config: SqliteKVStoreConfig):
+        self.db_path = config.db_path
+        self.table_name = "kvstore"
+        self._conn: aiosqlite.Connection | None = None
+
+    def __str__(self):
+        return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})"
+
+    def _is_memory_db(self) -> bool:
+        """Check if this is an in-memory database."""
+        return self.db_path == ":memory:" or "mode=memory" in self.db_path
+
+    async def initialize(self):
+        # Skip directory creation for in-memory databases and file: URIs
+        if not self._is_memory_db() and not self.db_path.startswith("file:"):
+            db_dir = os.path.dirname(self.db_path)
+            if db_dir:  # Only create if there's a directory component
+                os.makedirs(db_dir, exist_ok=True)
+
+        # Only use persistent connection for in-memory databases
+        # File-based databases use connection-per-operation to avoid hangs
+        if self._is_memory_db():
+            self._conn = await aiosqlite.connect(self.db_path)
+            await self._conn.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS {self.table_name} (
+                    key TEXT PRIMARY KEY,
+                    value TEXT,
+                    expiration TIMESTAMP
+                )
+                """
+            )
+            await self._conn.commit()
+        else:
+            # For file-based databases, just create the table
+            async with aiosqlite.connect(self.db_path) as db:
+                await db.execute(
+                    f"""
+                    CREATE TABLE IF NOT EXISTS {self.table_name} (
+                        key TEXT PRIMARY KEY,
+                        value TEXT,
+                        expiration TIMESTAMP
+                    )
+                    """
+                )
+                await db.commit()
+
+    async def shutdown(self):
+        """Close the persistent connection (only for in-memory databases)."""
+        if self._conn:
+            await self._conn.close()
+            self._conn = None
+
+    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
+        if self._conn:
+            # In-memory database with persistent connection
+            await self._conn.execute(
+                f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
+                (key, value, expiration),
+            )
+            await self._conn.commit()
+        else:
+            # File-based database with connection per operation
+            async with aiosqlite.connect(self.db_path) as db:
+                await db.execute(
+                    f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
+                    (key, value, expiration),
+                )
+                await db.commit()
+
+    async def get(self, key: str) -> str | None:
+        if self._conn:
+            # In-memory database with persistent connection
+            async with self._conn.execute(
+                f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
+            ) as cursor:
+                row = await cursor.fetchone()
+                if row is None:
+                    return None
+                value, expiration = row
+                if not isinstance(value, str):
+                    logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
+                    return None
+                return value
+        else:
+            # File-based database with connection per operation
+            async with aiosqlite.connect(self.db_path) as db:
+                async with db.execute(
+                    f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
+                ) as cursor:
+                    row = await cursor.fetchone()
+                    if row is None:
+                        return None
+                    value, expiration = row
+                    if not isinstance(value, str):
+                        logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
+                        return None
+                    return value
+
+    async def delete(self, key: str) -> None:
+        if self._conn:
+            # In-memory database with persistent connection
+            await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
+            await self._conn.commit()
+        else:
+            # File-based database with connection per operation
+            async with aiosqlite.connect(self.db_path) as db:
+                await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
+                await db.commit()
+
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
+        if self._conn:
+            # In-memory database with persistent connection
+            async with self._conn.execute(
+                f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
+                (start_key, end_key),
+            ) as cursor:
+                result = []
+                async for row in cursor:
+                    _, value, _ = row
+                    result.append(value)
+                return result
+        else:
+            # File-based database with connection per operation
+            async with aiosqlite.connect(self.db_path) as db:
+                async with db.execute(
+                    f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
+                    (start_key, end_key),
+                ) as cursor:
+                    result = []
+                    async for row in cursor:
+                        _, value, _ = row
+                        result.append(value)
+                    return result
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Get all keys in the given range."""
+        if self._conn:
+            # In-memory database with persistent connection
+            cursor = await self._conn.execute(
+                f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
+                (start_key, end_key),
+            )
+            rows = await cursor.fetchall()
+            return [row[0] for row in rows]
+        else:
+            # File-based database with connection per operation
+            async with aiosqlite.connect(self.db_path) as db:
+                cursor = await db.execute(
+                    f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
+                    (start_key, end_key),
+                )
+                rows = await cursor.fetchall()
+                return [row[0] for row in rows]
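
Aside: the persistent-connection special case above exists because SQLite hands every new connection to `:memory:` a fresh, empty database. A self-contained demonstration of that constraint (aiosqlite only, not project code):

```python
import asyncio

import aiosqlite


async def main() -> None:
    async with aiosqlite.connect(":memory:") as db:
        await db.execute("CREATE TABLE kv (key TEXT PRIMARY KEY, value TEXT)")
        await db.commit()

    # A second connection gets a brand-new empty database: the table is gone.
    async with aiosqlite.connect(":memory:") as db:
        cursor = await db.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
        assert await cursor.fetchall() == []


asyncio.run(main())
```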
diff --git a/src/llama_stack/core/storage/sqlstore/__init__.py b/src/llama_stack/core/storage/sqlstore/__init__.py
new file mode 100644
index 000000000..488782f28
--- /dev/null
+++ b/src/llama_stack/core/storage/sqlstore/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .sqlstore import * # noqa: F401,F403
diff --git a/src/llama_stack/core/storage/sqlstore/api.py b/src/llama_stack/core/storage/sqlstore/api.py
new file mode 100644
index 000000000..9ece28d2c
--- /dev/null
+++ b/src/llama_stack/core/storage/sqlstore/api.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack_api import PaginatedResponse
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore
+
+__all__ = ["ColumnDefinition", "ColumnType", "SqlStore", "PaginatedResponse"]
diff --git a/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py b/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py
new file mode 100644
index 000000000..ba95dd120
--- /dev/null
+++ b/src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py
@@ -0,0 +1,336 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from collections.abc import Mapping, Sequence
+from typing import Any, Literal
+
+from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
+from llama_stack.core.access_control.conditions import ProtectedResource
+from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
+from llama_stack.core.datatypes import User
+from llama_stack.core.request_headers import get_authenticated_user
+from llama_stack.core.storage.datatypes import StorageBackendType
+from llama_stack.log import get_logger
+
+from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
+
+logger = get_logger(name=__name__, category="providers::utils")
+
+# Hardcoded copy of the default policy that our SQL filtering implements
+# WARNING: If default_policy() changes, this constant must be updated accordingly
+# or SQL filtering will fall back to conservative mode (safe but less performant)
+#
+# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
+# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
+# - Public records (no access_attributes) are always accessible
+# - Records with access_attributes require user to match ALL categories that exist in the resource
+# - Missing categories in the resource are treated as "no restriction" (allow)
+# - Within each category, user needs ANY matching value (OR logic)
+# - Between categories, user needs ALL categories to match (AND logic)
+SQL_OPTIMIZED_POLICY = [
+    AccessRule(
+        permit=Scope(actions=list(Action)),
+        when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
+    ),
+]
+
+
+def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
+    """Add access control attributes to a data item."""
+    enhanced = dict(item)
+    if current_user:
+        enhanced["owner_principal"] = current_user.principal
+        enhanced["access_attributes"] = current_user.attributes
+    else:
+        # IMPORTANT: Use empty string and null value (not None) to match public access filter
+        # The public access filter in _get_public_access_conditions() expects:
+        #   - owner_principal = '' (empty string)
+        #   - access_attributes = null (JSON null, which serializes to the string 'null')
+        # Setting them to None (SQL NULL) will cause rows to be filtered out on read.
+        enhanced["owner_principal"] = ""
+        enhanced["access_attributes"] = None  # Pydantic/JSON will serialize this as JSON null
+    return enhanced
+
+
+class SqlRecord(ProtectedResource):
+    def __init__(self, record_id: str, table_name: str, owner: User):
+        self.type = f"sql_record::{table_name}"
+        self.identifier = record_id
+        self.owner = owner
+
+
+class AuthorizedSqlStore:
+    """
+    Authorization layer for SqlStore that provides access control functionality.
+
+    This class composes a base SqlStore and adds authorization methods that handle
+    access control policies, user attribute capture, and SQL filtering optimization.
+    """
+
+    def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
+        """
+        Initialize the authorization layer.
+
+        :param sql_store: Base SqlStore implementation to wrap
+        :param policy: Access control policy to use for authorization
+        """
+        self.sql_store = sql_store
+        self.policy = policy
+        self._detect_database_type()
+        self._validate_sql_optimized_policy()
+
+    def _detect_database_type(self) -> None:
+        """Detect the database type from the underlying SQL store."""
+        if not hasattr(self.sql_store, "config"):
+            raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
+
+        self.database_type = self.sql_store.config.type.value
+        if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _validate_sql_optimized_policy(self) -> None:
+        """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
+
+        This ensures that if default_policy() changes, we detect the mismatch and
+        can update our SQL filtering logic accordingly.
+        """
+        actual_default = default_policy()
+
+        if SQL_OPTIMIZED_POLICY != actual_default:
+            logger.warning(
+                f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
+                f"SQL filtering will use conservative mode. "
+                f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
+            )
+
+    async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
+        """Create a table with built-in access control support."""
+
+        enhanced_schema = dict(schema)
+        if "access_attributes" not in enhanced_schema:
+            enhanced_schema["access_attributes"] = ColumnType.JSON
+        if "owner_principal" not in enhanced_schema:
+            enhanced_schema["owner_principal"] = ColumnType.STRING
+
+        await self.sql_store.create_table(table, enhanced_schema)
+        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
+        await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
+
+    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
+        """Insert a row or batch of rows with automatic access control attribute capture."""
+        current_user = get_authenticated_user()
+        enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
+        if isinstance(data, Mapping):
+            enhanced_data = _enhance_item_with_access_control(data, current_user)
+        else:
+            enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
+        await self.sql_store.insert(table, enhanced_data)
+
+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        """Upsert a row with automatic access control attribute capture."""
+        current_user = get_authenticated_user()
+        enhanced_data = _enhance_item_with_access_control(data, current_user)
+        await self.sql_store.upsert(
+            table=table,
+            data=enhanced_data,
+            conflict_columns=conflict_columns,
+            update_columns=update_columns,
+        )
+
+    async def fetch_all(
+        self,
+        table: str,
+        where: Mapping[str, Any] | None = None,
+        limit: int | None = None,
+        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
+        cursor: tuple[str, str] | None = None,
+    ) -> PaginatedResponse:
+        """Fetch all rows with automatic access control filtering."""
+        access_where = self._build_access_control_where_clause(self.policy)
+        rows = await self.sql_store.fetch_all(
+            table=table,
+            where=where,
+            where_sql=access_where,
+            limit=limit,
+            order_by=order_by,
+            cursor=cursor,
+        )
+
+        current_user = get_authenticated_user()
+        filtered_rows = []
+
+        for row in rows.data:
+            stored_access_attrs = row.get("access_attributes")
+            stored_owner_principal = row.get("owner_principal") or ""
+
+            record_id = row.get("id", "unknown")
+            sql_record = SqlRecord(
+                str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
+            )
+
+            if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
+                filtered_rows.append(row)
+
+        return PaginatedResponse(
+            data=filtered_rows,
+            has_more=rows.has_more,
+        )
+
+    async def fetch_one(
+        self,
+        table: str,
+        where: Mapping[str, Any] | None = None,
+        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
+    ) -> dict[str, Any] | None:
+        """Fetch one row with automatic access control checking."""
+        results = await self.fetch_all(
+            table=table,
+            where=where,
+            limit=1,
+            order_by=order_by,
+        )
+
+        return results.data[0] if results.data else None
+
+    async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
+        """Update rows with automatic access control attribute capture."""
+        enhanced_data = dict(data)
+
+        current_user = get_authenticated_user()
+        if current_user:
+            enhanced_data["owner_principal"] = current_user.principal
+            enhanced_data["access_attributes"] = current_user.attributes
+        else:
+            # IMPORTANT: Use empty string for owner_principal to match public access filter
+            enhanced_data["owner_principal"] = ""
+            enhanced_data["access_attributes"] = None  # Will serialize as JSON null
+
+        await self.sql_store.update(table, enhanced_data, where)
+
+    async def delete(self, table: str, where: Mapping[str, Any]) -> None:
+        """Delete rows with automatic access control filtering."""
+        await self.sql_store.delete(table, where)
+
+    def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
+        """Build SQL WHERE clause for access control filtering.
+
+        Only applies SQL filtering for the default policy to ensure correctness.
+        For custom policies, uses conservative filtering to avoid blocking legitimate access.
+        """
+        current_user = get_authenticated_user()
+
+        if not policy or policy == SQL_OPTIMIZED_POLICY:
+            return self._build_default_policy_where_clause(current_user)
+        else:
+            return self._build_conservative_where_clause()
+
+    def _json_extract(self, column: str, path: str) -> str:
+        """Extract JSON value (keeping JSON type).
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value
+        """
+        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
+            return f"{column}->'{path}'"
+        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _json_extract_text(self, column: str, path: str) -> str:
+        """Extract JSON value as text.
+
+        Args:
+            column: The JSON column name
+            path: The JSON path (e.g., 'roles', 'teams')
+
+        Returns:
+            SQL expression to extract JSON value as text
+        """
+        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
+            return f"{column}->>'{path}'"
+        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
+            return f"JSON_EXTRACT({column}, '$.{path}')"
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+
+    def _get_public_access_conditions(self) -> list[str]:
+        """Get the SQL conditions for public access.
+
+        Public records are those with:
+        - owner_principal = '' (empty string)
+        - access_attributes is either SQL NULL or JSON null
+
+        Note: Different databases serialize None differently:
+        - SQLite: None -> JSON null (text = 'null')
+        - Postgres: None -> SQL NULL (IS NULL)
+        """
+        conditions = ["owner_principal = ''"]
+        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
+            # Accept both SQL NULL and JSON null for Postgres compatibility
+            # This handles both old rows (SQL NULL) and new rows (JSON null)
+            conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
+        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
+            # SQLite serializes None as JSON null
+            conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
+        else:
+            raise ValueError(f"Unsupported database type: {self.database_type}")
+        return conditions
+
+    def _build_default_policy_where_clause(self, current_user: User | None) -> str:
+        """Build SQL WHERE clause for the default policy.
+
+        Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
+        This means user must match ALL attribute categories that exist in the resource.
+        """
+        base_conditions = self._get_public_access_conditions()
+        user_attr_conditions = []
+
+        if current_user and current_user.attributes:
+            for attr_key, user_values in current_user.attributes.items():
+                if user_values:
+                    value_conditions = []
+                    for value in user_values:
+                        # Check if JSON array contains the value
+                        escaped_value = value.replace("'", "''")
+                        json_text = self._json_extract_text("access_attributes", attr_key)
+                        value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
+
+                    if value_conditions:
+                        # Check if the category is missing (NULL)
+                        category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
+                        user_matches_category = f"({' OR '.join(value_conditions)})"
+                        user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
+
+            if user_attr_conditions:
+                all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
+                base_conditions.append(all_requirements_met)
+
+        return f"({' OR '.join(base_conditions)})"
+
+    def _build_conservative_where_clause(self) -> str:
+        """Conservative SQL filtering for custom policies.
+
+        Only filters records we're 100% certain would be denied by any reasonable policy.
+        """
+        current_user = get_authenticated_user()
+
+        if not current_user:
+            # Only allow public records
+            base_conditions = self._get_public_access_conditions()
+            return f"({' OR '.join(base_conditions)})"
+
+        return "1=1"
+ """ + base_conditions = self._get_public_access_conditions() + user_attr_conditions = [] + + if current_user and current_user.attributes: + for attr_key, user_values in current_user.attributes.items(): + if user_values: + value_conditions = [] + for value in user_values: + # Check if JSON array contains the value + escaped_value = value.replace("'", "''") + json_text = self._json_extract_text("access_attributes", attr_key) + value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')") + + if value_conditions: + # Check if the category is missing (NULL) + category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL" + user_matches_category = f"({' OR '.join(value_conditions)})" + user_attr_conditions.append(f"({category_missing} OR {user_matches_category})") + + if user_attr_conditions: + all_requirements_met = f"({' AND '.join(user_attr_conditions)})" + base_conditions.append(all_requirements_met) + + return f"({' OR '.join(base_conditions)})" + + def _build_conservative_where_clause(self) -> str: + """Conservative SQL filtering for custom policies. + + Only filters records we're 100% certain would be denied by any reasonable policy. + """ + current_user = get_authenticated_user() + + if not current_user: + # Only allow public records + base_conditions = self._get_public_access_conditions() + return f"({' OR '.join(base_conditions)})" + + return "1=1" diff --git a/src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py new file mode 100644 index 000000000..10009d396 --- /dev/null +++ b/src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py @@ -0,0 +1,374 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from collections.abc import Mapping, Sequence +from typing import Any, Literal, cast + +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + Float, + Integer, + MetaData, + String, + Table, + Text, + event, + inspect, + select, + text, +) +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine +from sqlalchemy.sql.elements import ColumnElement + +from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig +from llama_stack.log import get_logger +from llama_stack_api import PaginatedResponse + +from .api import ColumnDefinition, ColumnType, SqlStore + +logger = get_logger(name=__name__, category="providers::utils") + +TYPE_MAPPING: dict[ColumnType, Any] = { + ColumnType.INTEGER: Integer, + ColumnType.STRING: String, + ColumnType.FLOAT: Float, + ColumnType.BOOLEAN: Boolean, + ColumnType.DATETIME: DateTime, + ColumnType.TEXT: Text, + ColumnType.JSON: JSON, +} + + +def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement: + """Return a SQLAlchemy expression for a where condition. + + `value` may be a simple scalar (equality) or a mapping like {">": 123}. + The returned expression is a SQLAlchemy ColumnElement usable in query.where(...). 
+ """ + if isinstance(value, Mapping): + if len(value) != 1: + raise ValueError(f"Operator mapping must have a single operator, got: {value}") + op, operand = next(iter(value.items())) + if op == "==" or op == "=": + return cast(ColumnElement[Any], column == operand) + if op == ">": + return cast(ColumnElement[Any], column > operand) + if op == "<": + return cast(ColumnElement[Any], column < operand) + if op == ">=": + return cast(ColumnElement[Any], column >= operand) + if op == "<=": + return cast(ColumnElement[Any], column <= operand) + raise ValueError(f"Unsupported operator '{op}' in where mapping") + return cast(ColumnElement[Any], column == value) + + +class SqlAlchemySqlStoreImpl(SqlStore): + def __init__(self, config: SqlAlchemySqlStoreConfig): + self.config = config + self._is_sqlite_backend = "sqlite" in self.config.engine_str + self.async_session = async_sessionmaker(self.create_engine()) + self.metadata = MetaData() + + def create_engine(self) -> AsyncEngine: + # Configure connection args for better concurrency support + connect_args = {} + if self._is_sqlite_backend: + # SQLite-specific optimizations for concurrent access + # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases + connect_args["timeout"] = 5.0 + connect_args["check_same_thread"] = False # Allow usage across asyncio tasks + + engine = create_async_engine( + self.config.engine_str, + pool_pre_ping=True, + connect_args=connect_args, + ) + + # Enable WAL mode for SQLite to support concurrent readers and writers + if self._is_sqlite_backend: + + @event.listens_for(engine.sync_engine, "connect") + def set_sqlite_pragma(dbapi_conn, connection_record): + cursor = dbapi_conn.cursor() + # Enable Write-Ahead Logging for better concurrency + cursor.execute("PRAGMA journal_mode=WAL") + # Set busy timeout to 5 seconds (retry instead of immediate failure) + # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue + cursor.execute("PRAGMA busy_timeout=5000") + # Use NORMAL synchronous mode for better performance (still safe with WAL) + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.close() + + return engine + + async def create_table( + self, + table: str, + schema: Mapping[str, ColumnType | ColumnDefinition], + ) -> None: + if not schema: + raise ValueError(f"No columns defined for table '{table}'.") + + sqlalchemy_columns: list[Column] = [] + + for col_name, col_props in schema.items(): + col_type = None + is_primary_key = False + is_nullable = True + + if isinstance(col_props, ColumnType): + col_type = col_props + elif isinstance(col_props, ColumnDefinition): + col_type = col_props.type + is_primary_key = col_props.primary_key + is_nullable = col_props.nullable + + sqlalchemy_type = TYPE_MAPPING.get(col_type) + if not sqlalchemy_type: + raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.") + + sqlalchemy_columns.append( + Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable) + ) + + if table not in self.metadata.tables: + sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns) + else: + sqlalchemy_table = self.metadata.tables[table] + + engine = self.create_engine() + async with engine.begin() as conn: + await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) + + async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: + async with self.async_session() as session: + await 
session.execute(self.metadata.tables[table].insert(), data) + await session.commit() + + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: + table_obj = self.metadata.tables[table] + dialect_insert = self._get_dialect_insert(table_obj) + insert_stmt = dialect_insert.values(**data) + + if update_columns is None: + update_columns = [col for col in data.keys() if col not in conflict_columns] + + update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns} + conflict_cols = [table_obj.c[col] for col in conflict_columns] + + stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping) + + async with self.async_session() as session: + await session.execute(stmt) + await session.commit() + + async def fetch_all( + self, + table: str, + where: Mapping[str, Any] | None = None, + where_sql: str | None = None, + limit: int | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + cursor: tuple[str, str] | None = None, + ) -> PaginatedResponse: + async with self.async_session() as session: + table_obj = self.metadata.tables[table] + query = select(table_obj) + + if where: + for key, value in where.items(): + query = query.where(_build_where_expr(table_obj.c[key], value)) + + if where_sql: + query = query.where(text(where_sql)) + + # Handle cursor-based pagination + if cursor: + # Validate cursor tuple format + if not isinstance(cursor, tuple) or len(cursor) != 2: + raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}") + + # Require order_by for cursor pagination + if not order_by: + raise ValueError("order_by is required when using cursor pagination") + + # Only support single-column ordering for cursor pagination + if len(order_by) != 1: + raise ValueError( + f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns" + ) + + cursor_key_column, cursor_id = cursor + order_column, order_direction = order_by[0] + + # Verify cursor_key_column exists + if cursor_key_column not in table_obj.c: + raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'") + + # Get cursor value for the order column + cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id) + cursor_result = await session.execute(cursor_query) + cursor_row = cursor_result.fetchone() + + if not cursor_row: + raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'") + + cursor_value = cursor_row[0] + + # Apply cursor condition based on sort direction + if order_direction == "desc": + query = query.where(table_obj.c[order_column] < cursor_value) + else: + query = query.where(table_obj.c[order_column] > cursor_value) + + # Apply ordering + if order_by: + if not isinstance(order_by, list): + raise ValueError( + f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" + ) + for order in order_by: + if not isinstance(order, tuple): + raise ValueError( + f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" + ) + name, order_type = order + if name not in table_obj.c: + raise ValueError(f"Column '{name}' not found in table '{table}'") + if order_type == "asc": + query = query.order_by(table_obj.c[name].asc()) + elif order_type == "desc": + query = query.order_by(table_obj.c[name].desc()) + else: + raise ValueError(f"Invalid order 
'{order_type}' for column '{name}'") + + # Fetch limit + 1 to determine has_more + fetch_limit = limit + if limit: + fetch_limit = limit + 1 + + if fetch_limit: + query = query.limit(fetch_limit) + + result = await session.execute(query) + # Iterate directly - if no rows, list comprehension yields empty list + rows = [dict(row._mapping) for row in result] + + # Always return pagination result + has_more = False + if limit and len(rows) > limit: + has_more = True + rows = rows[:limit] + + return PaginatedResponse(data=rows, has_more=has_more) + + async def fetch_one( + self, + table: str, + where: Mapping[str, Any] | None = None, + where_sql: str | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> dict[str, Any] | None: + result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by) + if not result.data: + return None + return result.data[0] + + async def update( + self, + table: str, + data: Mapping[str, Any], + where: Mapping[str, Any], + ) -> None: + if not where: + raise ValueError("where is required for update") + + async with self.async_session() as session: + stmt = self.metadata.tables[table].update() + for key, value in where.items(): + stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value)) + await session.execute(stmt, data) + await session.commit() + + async def delete(self, table: str, where: Mapping[str, Any]) -> None: + if not where: + raise ValueError("where is required for delete") + + async with self.async_session() as session: + stmt = self.metadata.tables[table].delete() + for key, value in where.items(): + stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value)) + await session.execute(stmt) + await session.commit() + + async def add_column_if_not_exists( + self, + table: str, + column_name: str, + column_type: ColumnType, + nullable: bool = True, + ) -> None: + """Add a column to an existing table if the column doesn't already exist.""" + engine = self.create_engine() + + try: + async with engine.begin() as conn: + + def check_column_exists(sync_conn): + inspector = inspect(sync_conn) + + table_names = inspector.get_table_names() + if table not in table_names: + return False, False # table doesn't exist, column doesn't exist + + existing_columns = inspector.get_columns(table) + column_names = [col["name"] for col in existing_columns] + + return True, column_name in column_names # table exists, column exists or not + + table_exists, column_exists = await conn.run_sync(check_column_exists) + if not table_exists or column_exists: + return + + sqlalchemy_type = TYPE_MAPPING.get(column_type) + if not sqlalchemy_type: + raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.") + + # Create the ALTER TABLE statement + # Note: We need to get the dialect-specific type name + dialect = engine.dialect + type_impl = sqlalchemy_type() + compiled_type = type_impl.compile(dialect=dialect) + + nullable_clause = "" if nullable else " NOT NULL" + add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") + + await conn.execute(add_column_sql) + except Exception as e: + # If any error occurs during migration, log it but don't fail + # The table creation will handle adding the column + logger.error(f"Error adding column {column_name} to table {table}: {e}") + pass + + def _get_dialect_insert(self, table: Table): + if self._is_sqlite_backend: + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + + 
return sqlite_insert(table) + else: + from sqlalchemy.dialects.postgresql import insert as pg_insert + + return pg_insert(table) diff --git a/src/llama_stack/core/storage/sqlstore/sqlstore.py b/src/llama_stack/core/storage/sqlstore/sqlstore.py new file mode 100644 index 000000000..9409b7d00 --- /dev/null +++ b/src/llama_stack/core/storage/sqlstore/sqlstore.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from threading import Lock +from typing import Annotated, cast + +from pydantic import Field + +from llama_stack.core.storage.datatypes import ( + PostgresSqlStoreConfig, + SqliteSqlStoreConfig, + SqlStoreReference, + StorageBackendConfig, + StorageBackendType, +) + +from .api import SqlStore + +sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"] + +_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {} +_SQLSTORE_INSTANCES: dict[str, SqlStore] = {} +_SQLSTORE_LOCKS: dict[str, Lock] = {} + + +SqlStoreConfig = Annotated[ + SqliteSqlStoreConfig | PostgresSqlStoreConfig, + Field(discriminator="type"), +] + + +def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]: + """Get pip packages for SQL store config, handling both dict and object cases.""" + if isinstance(store_config, dict): + store_type = store_config.get("type") + if store_type == StorageBackendType.SQL_SQLITE.value: + return SqliteSqlStoreConfig.pip_packages() + elif store_type == StorageBackendType.SQL_POSTGRES.value: + return PostgresSqlStoreConfig.pip_packages() + else: + raise ValueError(f"Unknown SQL store type: {store_type}") + else: + return store_config.pip_packages() + + +def sqlstore_impl(reference: SqlStoreReference) -> SqlStore: + backend_name = reference.backend + + backend_config = _SQLSTORE_BACKENDS.get(backend_name) + if backend_config is None: + raise ValueError( + f"Unknown SQL store backend '{backend_name}'. 
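
Aside: the `_build_where_expr` helper lets callers mix scalar equality with single-operator mappings in the same `where` dict. A hedged usage sketch (the table and column names are illustrative, not from this patch):

```python
from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl


async def recent_completions(store: SqlAlchemySqlStoreImpl) -> None:
    # Scalar values mean equality; a single-operator mapping expresses a range.
    page = await store.fetch_all(
        "chat_completions",
        where={"model": "llama-3", "created": {">=": 1_700_000_000}},
        order_by=[("created", "desc")],
        limit=50,
    )
    print(len(page.data), page.has_more)
```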
diff --git a/src/llama_stack/core/storage/sqlstore/sqlstore.py b/src/llama_stack/core/storage/sqlstore/sqlstore.py
new file mode 100644
index 000000000..9409b7d00
--- /dev/null
+++ b/src/llama_stack/core/storage/sqlstore/sqlstore.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from threading import Lock
+from typing import Annotated, cast
+
+from pydantic import Field
+
+from llama_stack.core.storage.datatypes import (
+    PostgresSqlStoreConfig,
+    SqliteSqlStoreConfig,
+    SqlStoreReference,
+    StorageBackendConfig,
+    StorageBackendType,
+)
+
+from .api import SqlStore
+
+sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
+
+_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
+_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
+_SQLSTORE_LOCKS: dict[str, Lock] = {}
+
+
+SqlStoreConfig = Annotated[
+    SqliteSqlStoreConfig | PostgresSqlStoreConfig,
+    Field(discriminator="type"),
+]
+
+
+def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
+    """Get pip packages for SQL store config, handling both dict and object cases."""
+    if isinstance(store_config, dict):
+        store_type = store_config.get("type")
+        if store_type == StorageBackendType.SQL_SQLITE.value:
+            return SqliteSqlStoreConfig.pip_packages()
+        elif store_type == StorageBackendType.SQL_POSTGRES.value:
+            return PostgresSqlStoreConfig.pip_packages()
+        else:
+            raise ValueError(f"Unknown SQL store type: {store_type}")
+    else:
+        return store_config.pip_packages()
+
+
+def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
+    backend_name = reference.backend
+
+    backend_config = _SQLSTORE_BACKENDS.get(backend_name)
+    if backend_config is None:
+        raise ValueError(
+            f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
+        )
+
+    existing = _SQLSTORE_INSTANCES.get(backend_name)
+    if existing:
+        return existing
+
+    lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
+    with lock:
+        existing = _SQLSTORE_INSTANCES.get(backend_name)
+        if existing:
+            return existing
+
+        if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
+            from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
+
+            config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
+            instance = SqlAlchemySqlStoreImpl(config)
+            _SQLSTORE_INSTANCES[backend_name] = instance
+            return instance
+        else:
+            raise ValueError(f"Unknown sqlstore type {backend_config.type}")
+
+
+def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
+    """Register the set of available SQL store backends for reference resolution."""
+    global _SQLSTORE_BACKENDS
+    global _SQLSTORE_INSTANCES
+    global _SQLSTORE_LOCKS
+
+    _SQLSTORE_BACKENDS.clear()
+    _SQLSTORE_INSTANCES.clear()
+    _SQLSTORE_LOCKS.clear()
+    for name, cfg in backends.items():
+        _SQLSTORE_BACKENDS[name] = cfg
diff --git a/src/llama_stack/core/store/registry.py b/src/llama_stack/core/store/registry.py
index 6ff9e575b..1f604ba73 100644
--- a/src/llama_stack/core/store/registry.py
+++ b/src/llama_stack/core/store/registry.py
@@ -13,7 +13,7 @@ import pydantic
 from llama_stack.core.datatypes import RoutableObjectWithProvider
 from llama_stack.core.storage.datatypes import KVStoreReference
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
 
 logger = get_logger(__name__, category="core::registry")
diff --git a/src/llama_stack/distributions/starter/starter.py b/src/llama_stack/distributions/starter/starter.py
index 4c21a8c99..51784a0f2 100644
--- a/src/llama_stack/distributions/starter/starter.py
+++ b/src/llama_stack/distributions/starter/starter.py
@@ -35,8 +35,8 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
 )
 from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
 from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
-from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
+from llama_stack.core.storage.kvstore.config import PostgresKVStoreConfig
+from llama_stack.core.storage.sqlstore.sqlstore import PostgresSqlStoreConfig
 from llama_stack_api import RemoteProviderSpec
diff --git a/src/llama_stack/distributions/template.py b/src/llama_stack/distributions/template.py
index 5755a26de..723420ec8 100644
--- a/src/llama_stack/distributions/template.py
+++ b/src/llama_stack/distributions/template.py
@@ -38,10 +38,10 @@ from llama_stack.core.storage.datatypes import (
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
+from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
+from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages
+from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
 from llama_stack_api import DatasetPurpose, ModelType
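
Aside: both `get_pip_packages` helpers accept either typed config objects or plain dicts, which is how distribution templates compute their dependency lists. A sketch using the object form (the `db_path` field names are assumptions):

```python
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages
from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages

# Union of KV and SQL dependencies, deduplicated, as a template would collect them.
deps = sorted(
    set(get_kv_pip_packages(SqliteKVStoreConfig(db_path="/tmp/kv.db")))
    | set(get_sql_pip_packages(SqliteSqlStoreConfig(db_path="/tmp/store.db")))
)
print(deps)
```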
diff --git a/src/llama_stack/core/store/registry.py b/src/llama_stack/core/store/registry.py
index 6ff9e575b..1f604ba73 100644
--- a/src/llama_stack/core/store/registry.py
+++ b/src/llama_stack/core/store/registry.py
@@ -13,7 +13,7 @@ import pydantic
 from llama_stack.core.datatypes import RoutableObjectWithProvider
 from llama_stack.core.storage.datatypes import KVStoreReference
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
 
 logger = get_logger(__name__, category="core::registry")
diff --git a/src/llama_stack/distributions/starter/starter.py b/src/llama_stack/distributions/starter/starter.py
index 4c21a8c99..51784a0f2 100644
--- a/src/llama_stack/distributions/starter/starter.py
+++ b/src/llama_stack/distributions/starter/starter.py
@@ -35,8 +35,8 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
 )
 from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
 from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
-from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
+from llama_stack.core.storage.kvstore.config import PostgresKVStoreConfig
+from llama_stack.core.storage.sqlstore.sqlstore import PostgresSqlStoreConfig
 from llama_stack_api import RemoteProviderSpec
diff --git a/src/llama_stack/distributions/template.py b/src/llama_stack/distributions/template.py
index 5755a26de..723420ec8 100644
--- a/src/llama_stack/distributions/template.py
+++ b/src/llama_stack/distributions/template.py
@@ -38,10 +38,10 @@ from llama_stack.core.storage.datatypes import (
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
+from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
+from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages
+from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
 from llama_stack_api import DatasetPurpose, ModelType
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py
index 347f6fdb1..6777b10e5 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -7,7 +7,7 @@
 from llama_stack.core.datatypes import AccessRule
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
+from llama_stack.core.storage.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 from llama_stack_api import (
     Agents,
diff --git a/src/llama_stack/providers/inline/batches/reference/__init__.py b/src/llama_stack/providers/inline/batches/reference/__init__.py
index 11c4b06a9..b48c82864 100644
--- a/src/llama_stack/providers/inline/batches/reference/__init__.py
+++ b/src/llama_stack/providers/inline/batches/reference/__init__.py
@@ -7,7 +7,7 @@ from typing import Any
 
 from llama_stack.core.datatypes import AccessRule, Api
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack_api import Files, Inference, Models
 
 from .batches import ReferenceBatchesImpl
diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py
index 73727799d..3969f8f2a 100644
--- a/src/llama_stack/providers/inline/batches/reference/batches.py
+++ b/src/llama_stack/providers/inline/batches/reference/batches.py
@@ -17,7 +17,7 @@ from openai.types.batch import BatchError, Errors
 from pydantic import BaseModel
 
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.core.storage.kvstore import KVStore
 from llama_stack_api import (
     Batches,
     BatchObject,
diff --git a/src/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/src/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index 6ab1a540f..be66c8ba5 100644
--- a/src/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/src/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -6,7 +6,7 @@ from typing import Any
 
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse
diff --git a/src/llama_stack/providers/inline/eval/meta_reference/eval.py b/src/llama_stack/providers/inline/eval/meta_reference/eval.py
index d43e569e2..f2459c5e0 100644
--- a/src/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/src/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -9,7 +9,7 @@ from typing import Any
 from tqdm import tqdm
 
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack_api import (
     Agents,
     Benchmark,
diff --git a/src/llama_stack/providers/inline/files/localfs/files.py b/src/llama_stack/providers/inline/files/localfs/files.py
index 5fb35a378..0656a334d 100644
--- a/src/llama_stack/providers/inline/files/localfs/files.py
+++ b/src/llama_stack/providers/inline/files/localfs/files.py
@@ -15,9 +15,8 @@ from llama_stack.core.datatypes import AccessRule
 from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.files.form_data import parse_expires_after
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     ExpiresAfter,
     Files,
@@ -28,6 +27,7 @@ from llama_stack_api import (
     Order,
     ResourceNotFoundError,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
 
 from .config import LocalfsFilesImplConfig
diff --git a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py
index d52a54e6a..9f73e8f9f 100644
--- a/src/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/src/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -15,8 +15,7 @@ import numpy as np
 from numpy.typing import NDArray
 
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
 from llama_stack_api import (
@@ -32,6 +31,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore
 
 from .config import FaissVectorIOConfig
diff --git a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index 74bc349a5..6d29915cf 100644
--- a/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -15,8 +15,7 @@ import sqlite_vec  # type: ignore[import-untyped]
 from numpy.typing import NDArray
 
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
@@ -35,6 +34,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore
 
 logger = get_logger(name=__name__, category="vector_io")
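The import churn in these provider hunks follows one rule: factory functions now come from llama_stack.core.storage.*, while pure protocol and column types move behind llama_stack_api.internal.*. Summarized as a sketch drawn directly from the hunks above:

    # Factories and implementations: the core storage package.
    from llama_stack.core.storage.kvstore import kvstore_impl
    from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl

    # Protocol/type-only imports: the internal API package.
    from llama_stack_api.internal.kvstore import KVStore
    from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType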
diff --git a/src/llama_stack/providers/registry/agents.py b/src/llama_stack/providers/registry/agents.py
index 455be1ae7..2c68750a6 100644
--- a/src/llama_stack/providers/registry/agents.py
+++ b/src/llama_stack/providers/registry/agents.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.kvstore import kvstore_dependencies
+from llama_stack.core.storage.kvstore import kvstore_dependencies
 from llama_stack_api import (
     Api,
     InlineProviderSpec,
diff --git a/src/llama_stack/providers/registry/files.py b/src/llama_stack/providers/registry/files.py
index 024254b57..8ce8acd91 100644
--- a/src/llama_stack/providers/registry/files.py
+++ b/src/llama_stack/providers/registry/files.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
+from llama_stack.core.storage.sqlstore.sqlstore import sql_store_pip_packages
 from llama_stack_api import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
diff --git a/src/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/src/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index 72069f716..26390a63b 100644
--- a/src/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/src/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -6,7 +6,7 @@ from typing import Any
 from urllib.parse import parse_qs, urlparse
 
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse
diff --git a/src/llama_stack/providers/remote/files/openai/files.py b/src/llama_stack/providers/remote/files/openai/files.py
index d2f5a08eb..41192ddf7 100644
--- a/src/llama_stack/providers/remote/files/openai/files.py
+++ b/src/llama_stack/providers/remote/files/openai/files.py
@@ -11,9 +11,8 @@ from fastapi import Depends, File, Form, Response, UploadFile
 
 from llama_stack.core.datatypes import AccessRule
 from llama_stack.providers.utils.files.form_data import parse_expires_after
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     ExpiresAfter,
     Files,
@@ -24,6 +23,7 @@ from llama_stack_api import (
     Order,
     ResourceNotFoundError,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
 from openai import OpenAI
 
 from .config import OpenAIFilesImplConfig
diff --git a/src/llama_stack/providers/remote/files/s3/files.py b/src/llama_stack/providers/remote/files/s3/files.py
index 68822eb77..09a0a13a1 100644
--- a/src/llama_stack/providers/remote/files/s3/files.py
+++ b/src/llama_stack/providers/remote/files/s3/files.py
@@ -20,9 +20,8 @@ if TYPE_CHECKING:
 from llama_stack.core.datatypes import AccessRule
 from llama_stack.core.id_generation import generate_object_id
 from llama_stack.providers.utils.files.form_data import parse_expires_after
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     ExpiresAfter,
     Files,
@@ -33,6 +32,7 @@ from llama_stack_api import (
     Order,
     ResourceNotFoundError,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
 
 from .config import S3FilesImplConfig
diff --git a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 645b40661..0be9f2f36 100644
--- a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -13,8 +13,7 @@ from numpy.typing import NDArray
 
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
 from llama_stack_api import (
@@ -27,6 +26,7 @@ from llama_stack_api import (
     VectorStore,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore
 
 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
diff --git a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py
index aefa20317..13043d5d0 100644
--- a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -13,8 +13,7 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC
 
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_WEIGHTED,
@@ -34,6 +33,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore
 
 from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
diff --git a/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index 2901bad97..f46f721c6 100644
--- a/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/src/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -15,8 +15,7 @@ from pydantic import BaseModel, TypeAdapter
 
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
 from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
@@ -31,6 +30,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore
 
 from .config import PGVectorVectorIOConfig
diff --git a/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index 20ab653d0..721081f96 100644
--- a/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/src/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -15,7 +15,7 @@ from qdrant_client.models import PointStruct
 
 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
 from llama_stack_api import (
diff --git a/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index ba3e6b7ea..6d98b0a69 100644
--- a/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/src/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -14,8 +14,7 @@ from weaviate.classes.query import Filter, HybridFusion
 
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
@@ -35,6 +34,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore
 
 from .config import WeaviateVectorIOConfig
diff --git a/src/llama_stack/providers/utils/inference/inference_store.py b/src/llama_stack/providers/utils/inference/inference_store.py
index 49e3af7a1..2260f3247 100644
--- a/src/llama_stack/providers/utils/inference/inference_store.py
+++ b/src/llama_stack/providers/utils/inference/inference_store.py
@@ -19,9 +19,9 @@ from llama_stack_api import (
     Order,
 )
 
-from ..sqlstore.api import ColumnDefinition, ColumnType
-from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
+from llama_stack.core.storage.sqlstore.api import ColumnDefinition, ColumnType
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
 
 logger = get_logger(name=__name__, category="inference")
diff --git a/src/llama_stack/providers/utils/kvstore/__init__.py b/src/llama_stack/providers/utils/kvstore/__init__.py
index 470a75d2d..129340d7f 100644
--- a/src/llama_stack/providers/utils/kvstore/__init__.py
+++ b/src/llama_stack/providers/utils/kvstore/__init__.py
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .kvstore import *  # noqa: F401, F403
+from llama_stack.core.storage.kvstore import *  # noqa: F401,F403
diff --git a/src/llama_stack/providers/utils/kvstore/api.py b/src/llama_stack/providers/utils/kvstore/api.py
index d17dc66e1..4bf25dc12 100644
--- a/src/llama_stack/providers/utils/kvstore/api.py
+++ b/src/llama_stack/providers/utils/kvstore/api.py
@@ -4,18 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from datetime import datetime
-from typing import Protocol
-
-
-class KVStore(Protocol):
-    # TODO: make the value type bytes instead of str
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ...
-
-    async def get(self, key: str) -> str | None: ...
-
-    async def delete(self, key: str) -> None: ...
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
+from llama_stack.core.storage.kvstore.api import *  # noqa: F401,F403
diff --git a/src/llama_stack/providers/utils/kvstore/config.py b/src/llama_stack/providers/utils/kvstore/config.py
index c0582abc4..841bfff91 100644
--- a/src/llama_stack/providers/utils/kvstore/config.py
+++ b/src/llama_stack/providers/utils/kvstore/config.py
@@ -4,36 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Annotated
-
-from pydantic import Field
-
-from llama_stack.core.storage.datatypes import (
-    MongoDBKVStoreConfig,
-    PostgresKVStoreConfig,
-    RedisKVStoreConfig,
-    SqliteKVStoreConfig,
-    StorageBackendType,
-)
-
-KVStoreConfig = Annotated[
-    RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
-]
-
-
-def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
-    """Get pip packages for KV store config, handling both dict and object cases."""
-    if isinstance(store_config, dict):
-        store_type = store_config.get("type")
-        if store_type == StorageBackendType.KV_SQLITE.value:
-            return SqliteKVStoreConfig.pip_packages()
-        elif store_type == StorageBackendType.KV_POSTGRES.value:
-            return PostgresKVStoreConfig.pip_packages()
-        elif store_type == StorageBackendType.KV_REDIS.value:
-            return RedisKVStoreConfig.pip_packages()
-        elif store_type == StorageBackendType.KV_MONGODB.value:
-            return MongoDBKVStoreConfig.pip_packages()
-        else:
-            raise ValueError(f"Unknown KV store type: {store_type}")
-    else:
-        return store_config.pip_packages()
+from llama_stack.core.storage.kvstore.config import *  # noqa: F401,F403
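get_pip_packages, now re-exported from its new home, accepts either a raw dict (e.g. config parsed from YAML before validation) or a typed config object, and both paths must agree. A small sketch of that contract:

    from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, StorageBackendType
    from llama_stack.core.storage.kvstore.config import get_pip_packages

    # A raw {"type": ...} dict and the validated config class resolve identically;
    # the dict branch dispatches on StorageBackendType values.
    assert get_pip_packages({"type": StorageBackendType.KV_SQLITE.value}) == SqliteKVStoreConfig.pip_packages()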
diff --git a/src/llama_stack/providers/utils/kvstore/kvstore.py b/src/llama_stack/providers/utils/kvstore/kvstore.py
index 5b8d77102..180d915d7 100644
--- a/src/llama_stack/providers/utils/kvstore/kvstore.py
+++ b/src/llama_stack/providers/utils/kvstore/kvstore.py
@@ -4,115 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from __future__ import annotations
-
-import asyncio
-from collections import defaultdict
-
-from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
-
-from .api import KVStore
-from .config import KVStoreConfig
-
-
-def kvstore_dependencies():
-    """
-    Returns all possible kvstore dependencies for registry/provider specifications.
-
-    NOTE: For specific kvstore implementations, use config.pip_packages instead.
-    This function returns the union of all dependencies for cases where the specific
-    kvstore type is not known at declaration time (e.g., provider registries).
-    """
-    return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
-
-
-class InmemoryKVStoreImpl(KVStore):
-    def __init__(self):
-        self._store = {}
-
-    async def initialize(self) -> None:
-        pass
-
-    async def get(self, key: str) -> str | None:
-        return self._store.get(key)
-
-    async def set(self, key: str, value: str) -> None:
-        self._store[key] = value
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        """Get all keys in the given range."""
-        return [key for key in self._store.keys() if key >= start_key and key < end_key]
-
-    async def delete(self, key: str) -> None:
-        del self._store[key]
-
-
-_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
-_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
-_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
-
-
-def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
-    """Register the set of available KV store backends for reference resolution."""
-    global _KVSTORE_BACKENDS
-    global _KVSTORE_INSTANCES
-    global _KVSTORE_LOCKS
-
-    _KVSTORE_BACKENDS.clear()
-    _KVSTORE_INSTANCES.clear()
-    _KVSTORE_LOCKS.clear()
-    for name, cfg in backends.items():
-        _KVSTORE_BACKENDS[name] = cfg
-
-
-async def kvstore_impl(reference: KVStoreReference) -> KVStore:
-    backend_name = reference.backend
-    cache_key = (backend_name, reference.namespace)
-
-    existing = _KVSTORE_INSTANCES.get(cache_key)
-    if existing:
-        return existing
-
-    backend_config = _KVSTORE_BACKENDS.get(backend_name)
-    if backend_config is None:
-        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
-
-    lock = _KVSTORE_LOCKS[cache_key]
-    async with lock:
-        existing = _KVSTORE_INSTANCES.get(cache_key)
-        if existing:
-            return existing
-
-        config = backend_config.model_copy()
-        config.namespace = reference.namespace
-
-        if config.type == StorageBackendType.KV_REDIS.value:
-            from .redis import RedisKVStoreImpl
-
-            impl = RedisKVStoreImpl(config)
-        elif config.type == StorageBackendType.KV_SQLITE.value:
-            from .sqlite import SqliteKVStoreImpl
-
-            impl = SqliteKVStoreImpl(config)
-        elif config.type == StorageBackendType.KV_POSTGRES.value:
-            from .postgres import PostgresKVStoreImpl
-
-            impl = PostgresKVStoreImpl(config)
-        elif config.type == StorageBackendType.KV_MONGODB.value:
-            from .mongodb import MongoDBKVStoreImpl
-
-            impl = MongoDBKVStoreImpl(config)
-        else:
-            raise ValueError(f"Unknown kvstore type {config.type}")
-
-        await impl.initialize()
-        _KVSTORE_INSTANCES[cache_key] = impl
-        return impl
+from llama_stack.core.storage.kvstore.kvstore import *  # noqa: F401,F403
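Unlike the SQL side, which caches one instance per backend, the moved kvstore_impl caches per (backend, namespace) pair and stamps the namespace onto a copy of the backend config. A usage sketch; the db_path field name is an assumption about SqliteKVStoreConfig:

    import asyncio

    from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
    from llama_stack.core.storage.kvstore.kvstore import kvstore_impl, register_kvstore_backends


    async def main() -> None:
        register_kvstore_backends({"default": SqliteKVStoreConfig(db_path="/tmp/kv.db")})  # db_path assumed
        a = await kvstore_impl(KVStoreReference(backend="default", namespace="registry"))
        b = await kvstore_impl(KVStoreReference(backend="default", namespace="registry"))
        c = await kvstore_impl(KVStoreReference(backend="default", namespace="batches"))
        # One initialized instance per (backend, namespace); distinct namespaces get distinct stores.
        assert a is b and a is not c


    asyncio.run(main())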
diff --git a/src/llama_stack/providers/utils/kvstore/mongodb/__init__.py b/src/llama_stack/providers/utils/kvstore/mongodb/__init__.py
index a42fee76f..2a7efc55d 100644
--- a/src/llama_stack/providers/utils/kvstore/mongodb/__init__.py
+++ b/src/llama_stack/providers/utils/kvstore/mongodb/__init__.py
@@ -4,6 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .mongodb import MongoDBKVStoreImpl
-
-__all__ = ["MongoDBKVStoreImpl"]
+from llama_stack.core.storage.kvstore.mongodb import *  # noqa: F401,F403
diff --git a/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py
index 964c45090..df2631689 100644
--- a/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py
+++ b/src/llama_stack/providers/utils/kvstore/mongodb/mongodb.py
@@ -4,82 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from datetime import datetime
-
-from pymongo import AsyncMongoClient
-from pymongo.asynchronous.collection import AsyncCollection
-
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore
-
-from ..config import MongoDBKVStoreConfig
-
-log = get_logger(name=__name__, category="providers::utils")
-
-
-class MongoDBKVStoreImpl(KVStore):
-    def __init__(self, config: MongoDBKVStoreConfig):
-        self.config = config
-        self.conn: AsyncMongoClient | None = None
-
-    @property
-    def collection(self) -> AsyncCollection:
-        if self.conn is None:
-            raise RuntimeError("MongoDB connection is not initialized")
-        return self.conn[self.config.db][self.config.collection_name]
-
-    async def initialize(self) -> None:
-        try:
-            # Pass parameters explicitly to satisfy mypy - AsyncMongoClient doesn't accept **dict
-            self.conn = AsyncMongoClient(
-                host=self.config.host if self.config.host is not None else None,
-                port=self.config.port if self.config.port is not None else None,
-                username=self.config.user if self.config.user is not None else None,
-                password=self.config.password if self.config.password is not None else None,
-            )
-        except Exception as e:
-            log.exception("Could not connect to MongoDB database server")
-            raise RuntimeError("Could not connect to MongoDB database server") from e
-
-    def _namespaced_key(self, key: str) -> str:
-        if not self.config.namespace:
-            return key
-        return f"{self.config.namespace}:{key}"
-
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
-        key = self._namespaced_key(key)
-        update_query = {"$set": {"value": value, "expiration": expiration}}
-        await self.collection.update_one({"key": key}, update_query, upsert=True)
-
-    async def get(self, key: str) -> str | None:
-        key = self._namespaced_key(key)
-        query = {"key": key}
-        result = await self.collection.find_one(query, {"value": 1, "_id": 0})
-        return result["value"] if result else None
-
-    async def delete(self, key: str) -> None:
-        key = self._namespaced_key(key)
-        await self.collection.delete_one({"key": key})
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-        query = {
-            "key": {"$gte": start_key, "$lt": end_key},
-        }
-        cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
-        result = []
-        async for doc in cursor:
-            result.append(doc["value"])
-        return result
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-        query = {"key": {"$gte": start_key, "$lt": end_key}}
-        cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
-        # AsyncCursor requires async iteration
-        result = []
-        async for doc in cursor:
-            result.append(doc["key"])
-        return result
+from llama_stack.core.storage.kvstore.mongodb.mongodb import *  # noqa: F401,F403
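The Mongo, Postgres, and in-memory stores treat ranges as half-open (start_key <= key < end_key; the $gte/$lt pair above), while the Redis and SQLite paths below include the end key. The in-memory behavior is easy to check with a runnable sketch:

    import asyncio

    from llama_stack.core.storage.kvstore import InmemoryKVStoreImpl


    async def demo() -> None:
        kv = InmemoryKVStoreImpl()
        await kv.set("batch:001", "a")
        await kv.set("batch:002", "b")
        await kv.set("batch:010", "c")
        # Half-open range: the end key itself is excluded.
        assert await kv.keys_in_range("batch:001", "batch:010") == ["batch:001", "batch:002"]


    asyncio.run(demo())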
-from .postgres import PostgresKVStoreImpl  # noqa: F401 F403
+from llama_stack.core.storage.kvstore.postgres import *  # noqa: F401,F403
diff --git a/src/llama_stack/providers/utils/kvstore/postgres/postgres.py b/src/llama_stack/providers/utils/kvstore/postgres/postgres.py
index 56d6dbb48..ab7f27ccf 100644
--- a/src/llama_stack/providers/utils/kvstore/postgres/postgres.py
+++ b/src/llama_stack/providers/utils/kvstore/postgres/postgres.py
@@ -4,111 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from datetime import datetime
-
-import psycopg2
-from psycopg2.extras import DictCursor
-
-from llama_stack.log import get_logger
-
-from ..api import KVStore
-from ..config import PostgresKVStoreConfig
-
-log = get_logger(name=__name__, category="providers::utils")
-
-
-class PostgresKVStoreImpl(KVStore):
-    def __init__(self, config: PostgresKVStoreConfig):
-        self.config = config
-        self.conn = None
-        self.cursor = None
-
-    async def initialize(self) -> None:
-        try:
-            self.conn = psycopg2.connect(
-                host=self.config.host,
-                port=self.config.port,
-                database=self.config.db,
-                user=self.config.user,
-                password=self.config.password,
-                sslmode=self.config.ssl_mode,
-                sslrootcert=self.config.ca_cert_path,
-            )
-            self.conn.autocommit = True
-            self.cursor = self.conn.cursor(cursor_factory=DictCursor)
-
-            # Create table if it doesn't exist
-            self.cursor.execute(
-                f"""
-                CREATE TABLE IF NOT EXISTS {self.config.table_name} (
-                    key TEXT PRIMARY KEY,
-                    value TEXT,
-                    expiration TIMESTAMP
-                )
-                """
-            )
-        except Exception as e:
-            log.exception("Could not connect to PostgreSQL database server")
-            raise RuntimeError("Could not connect to PostgreSQL database server") from e
-
-    def _namespaced_key(self, key: str) -> str:
-        if not self.config.namespace:
-            return key
-        return f"{self.config.namespace}:{key}"
-
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
-        key = self._namespaced_key(key)
-        self.cursor.execute(
-            f"""
-            INSERT INTO {self.config.table_name} (key, value, expiration)
-            VALUES (%s, %s, %s)
-            ON CONFLICT (key) DO UPDATE
-            SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
-            """,
-            (key, value, expiration),
-        )
-
-    async def get(self, key: str) -> str | None:
-        key = self._namespaced_key(key)
-        self.cursor.execute(
-            f"""
-            SELECT value FROM {self.config.table_name}
-            WHERE key = %s
-            AND (expiration IS NULL OR expiration > NOW())
-            """,
-            (key,),
-        )
-        result = self.cursor.fetchone()
-        return result[0] if result else None
-
-    async def delete(self, key: str) -> None:
-        key = self._namespaced_key(key)
-        self.cursor.execute(
-            f"DELETE FROM {self.config.table_name} WHERE key = %s",
-            (key,),
-        )
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-
-        self.cursor.execute(
-            f"""
-            SELECT value FROM {self.config.table_name}
-            WHERE key >= %s AND key < %s
-            AND (expiration IS NULL OR expiration > NOW())
-            ORDER BY key
-            """,
-            (start_key, end_key),
-        )
-        return [row[0] for row in self.cursor.fetchall()]
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-
-        self.cursor.execute(
-            f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
-            (start_key, end_key),
-        )
-        return [row[0] for row in self.cursor.fetchall()]
+from llama_stack.core.storage.kvstore.postgres.postgres import *  # noqa: F401,F403
diff --git a/src/llama_stack/providers/utils/kvstore/redis/__init__.py b/src/llama_stack/providers/utils/kvstore/redis/__init__.py
index 94693ca43..6153fa93c 100644
--- a/src/llama_stack/providers/utils/kvstore/redis/__init__.py
+++ b/src/llama_stack/providers/utils/kvstore/redis/__init__.py
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .redis import RedisKVStoreImpl  # noqa: F401
+from llama_stack.core.storage.kvstore.redis import *  # noqa: F401,F403
diff --git a/src/llama_stack/providers/utils/kvstore/redis/redis.py b/src/llama_stack/providers/utils/kvstore/redis/redis.py
index 3d2d956c3..1a3b6f586 100644
--- a/src/llama_stack/providers/utils/kvstore/redis/redis.py
+++ b/src/llama_stack/providers/utils/kvstore/redis/redis.py
@@ -4,73 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from datetime import datetime
-
-from redis.asyncio import Redis
-
-from ..api import KVStore
-from ..config import RedisKVStoreConfig
-
-
-class RedisKVStoreImpl(KVStore):
-    def __init__(self, config: RedisKVStoreConfig):
-        self.config = config
-
-    async def initialize(self) -> None:
-        self.redis = Redis.from_url(self.config.url)
-
-    def _namespaced_key(self, key: str) -> str:
-        if not self.config.namespace:
-            return key
-        return f"{self.config.namespace}:{key}"
-
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
-        key = self._namespaced_key(key)
-        await self.redis.set(key, value)
-        if expiration:
-            await self.redis.expireat(key, expiration)
-
-    async def get(self, key: str) -> str | None:
-        key = self._namespaced_key(key)
-        value = await self.redis.get(key)
-        if value is None:
-            return None
-        await self.redis.ttl(key)
-        return value
-
-    async def delete(self, key: str) -> None:
-        key = self._namespaced_key(key)
-        await self.redis.delete(key)
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-        cursor = 0
-        pattern = start_key + "*"  # Match all keys starting with start_key prefix
-        matching_keys = []
-        while True:
-            cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
-
-            for key in keys:
-                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
-                if start_key <= key_str <= end_key:
-                    matching_keys.append(key)
-
-            if cursor == 0:
-                break
-
-        # Then fetch all values in a single MGET call
-        if matching_keys:
-            values = await self.redis.mget(matching_keys)
-            return [
-                value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
-            ]
-
-        return []
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        """Get all keys in the given range."""
-        matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
-        if not matching_keys:
-            return []
-        return [k.decode("utf-8") for k in matching_keys]
+from llama_stack.core.storage.kvstore.redis.redis import *  # noqa: F401,F403
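The moved Redis values_in_range relies on cursor-based SCAN with a key prefix plus one MGET, rather than a server-side range query. A standalone sketch of that pattern (assumes a reachable Redis; this restates the moved approach, it is not the moved class itself):

    import asyncio

    from redis.asyncio import Redis


    async def values_in_range(redis: Redis, start_key: str, end_key: str) -> list[str]:
        cursor, matching = 0, []
        while True:
            # SCAN iterates incrementally without blocking the server; stop when the cursor wraps to 0.
            cursor, keys = await redis.scan(cursor, match=start_key + "*", count=1000)
            matching += [k for k in keys if start_key <= (k.decode() if isinstance(k, bytes) else k) <= end_key]
            if cursor == 0:
                break
        if not matching:
            return []
        values = await redis.mget(matching)  # fetch every matched value in a single round trip
        return [v.decode() if isinstance(v, bytes) else v for v in values if v is not None]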
diff --git a/src/llama_stack/providers/utils/kvstore/sqlite/__init__.py b/src/llama_stack/providers/utils/kvstore/sqlite/__init__.py
index 03bc53c24..2272543a4 100644
--- a/src/llama_stack/providers/utils/kvstore/sqlite/__init__.py
+++ b/src/llama_stack/providers/utils/kvstore/sqlite/__init__.py
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .sqlite import SqliteKVStoreImpl  # noqa: F401
+from llama_stack.core.storage.kvstore.sqlite import *  # noqa: F401,F403
diff --git a/src/llama_stack/providers/utils/kvstore/sqlite/config.py b/src/llama_stack/providers/utils/kvstore/sqlite/config.py
index 0f8fa0a95..162a89b59 100644
--- a/src/llama_stack/providers/utils/kvstore/sqlite/config.py
+++ b/src/llama_stack/providers/utils/kvstore/sqlite/config.py
@@ -4,17 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel, Field
-
-from llama_stack_api import json_schema_type
-
-
-@json_schema_type
-class SqliteControlPlaneConfig(BaseModel):
-    db_path: str = Field(
-        description="File path for the sqlite database",
-    )
-    table_name: str = Field(
-        default="llamastack_control_plane",
-        description="Table into which all the keys will be placed",
-    )
+from llama_stack.core.storage.kvstore.sqlite.config import *  # noqa: F401,F403
diff --git a/src/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/src/llama_stack/providers/utils/kvstore/sqlite/sqlite.py
index a9a7a1304..a9ad4ee23 100644
--- a/src/llama_stack/providers/utils/kvstore/sqlite/sqlite.py
+++ b/src/llama_stack/providers/utils/kvstore/sqlite/sqlite.py
@@ -4,171 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import os
-from datetime import datetime
-
-import aiosqlite
-
-from llama_stack.log import get_logger
-
-from ..api import KVStore
-from ..config import SqliteKVStoreConfig
-
-logger = get_logger(name=__name__, category="providers::utils")
-
-
-class SqliteKVStoreImpl(KVStore):
-    def __init__(self, config: SqliteKVStoreConfig):
-        self.db_path = config.db_path
-        self.table_name = "kvstore"
-        self._conn: aiosqlite.Connection | None = None
-
-    def __str__(self):
-        return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})"
-
-    def _is_memory_db(self) -> bool:
-        """Check if this is an in-memory database."""
-        return self.db_path == ":memory:" or "mode=memory" in self.db_path
-
-    async def initialize(self):
-        # Skip directory creation for in-memory databases and file: URIs
-        if not self._is_memory_db() and not self.db_path.startswith("file:"):
-            db_dir = os.path.dirname(self.db_path)
-            if db_dir:  # Only create if there's a directory component
-                os.makedirs(db_dir, exist_ok=True)
-
-        # Only use persistent connection for in-memory databases
-        # File-based databases use connection-per-operation to avoid hangs
-        if self._is_memory_db():
-            self._conn = await aiosqlite.connect(self.db_path)
-            await self._conn.execute(
-                f"""
-                CREATE TABLE IF NOT EXISTS {self.table_name} (
-                    key TEXT PRIMARY KEY,
-                    value TEXT,
-                    expiration TIMESTAMP
-                )
-                """
-            )
-            await self._conn.commit()
-        else:
-            # For file-based databases, just create the table
-            async with aiosqlite.connect(self.db_path) as db:
-                await db.execute(
-                    f"""
-                    CREATE TABLE IF NOT EXISTS {self.table_name} (
-                        key TEXT PRIMARY KEY,
-                        value TEXT,
-                        expiration TIMESTAMP
-                    )
-                    """
-                )
-                await db.commit()
-
-    async def shutdown(self):
-        """Close the persistent connection (only for in-memory databases)."""
-        if self._conn:
-            await self._conn.close()
-            self._conn = None
-
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
-        if self._conn:
-            # In-memory database with persistent connection
-            await self._conn.execute(
-                f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
-                (key, value, expiration),
-            )
-            await self._conn.commit()
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                await db.execute(
-                    f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
-                    (key, value, expiration),
-                )
-                await db.commit()
-
-    async def get(self, key: str) -> str | None:
-        if self._conn:
-            # In-memory database with persistent connection
-            async with self._conn.execute(
-                f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
-            ) as cursor:
-                row = await cursor.fetchone()
-                if row is None:
-                    return None
-                value, expiration = row
-                if not isinstance(value, str):
-                    logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
-                    return None
-                return value
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                async with db.execute(
-                    f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
-                ) as cursor:
-                    row = await cursor.fetchone()
-                    if row is None:
-                        return None
-                    value, expiration = row
-                    if not isinstance(value, str):
-                        logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
-                        return None
-                    return value
-
-    async def delete(self, key: str) -> None:
-        if self._conn:
-            # In-memory database with persistent connection
-            await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
-            await self._conn.commit()
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
-                await db.commit()
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        if self._conn:
-            # In-memory database with persistent connection
-            async with self._conn.execute(
-                f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
-                (start_key, end_key),
-            ) as cursor:
-                result = []
-                async for row in cursor:
-                    _, value, _ = row
-                    result.append(value)
-                return result
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                async with db.execute(
-                    f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
-                    (start_key, end_key),
-                ) as cursor:
-                    result = []
-                    async for row in cursor:
-                        _, value, _ = row
-                        result.append(value)
-                    return result
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        """Get all keys in the given range."""
-        if self._conn:
-            # In-memory database with persistent connection
-            cursor = await self._conn.execute(
-                f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
-                (start_key, end_key),
-            )
-            rows = await cursor.fetchall()
-            return [row[0] for row in rows]
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                cursor = await db.execute(
-                    f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
-                    (start_key, end_key),
-                )
-                rows = await cursor.fetchall()
-                return [row[0] for row in rows]
+from llama_stack.core.storage.kvstore.sqlite.sqlite import *  # noqa: F401,F403
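The connection strategy in the SQLite impl just moved (persistent connection only for ":memory:", connection-per-operation otherwise) exists because every new sqlite connection to ":memory:" opens a brand-new empty database. A runnable illustration of that footgun:

    import asyncio

    import aiosqlite


    async def demo() -> None:
        async with aiosqlite.connect(":memory:") as db:
            await db.execute("CREATE TABLE kv (key TEXT PRIMARY KEY, value TEXT)")
            await db.commit()
        # A second connection to ":memory:" is a completely different database:
        async with aiosqlite.connect(":memory:") as db:
            async with db.execute("SELECT name FROM sqlite_master WHERE name = 'kv'") as cur:
                assert await cur.fetchone() is None


    asyncio.run(demo())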
- """ - pass - - async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: - """ - Insert a row or batch of rows into a table. - """ - pass - - async def upsert( - self, - table: str, - data: Mapping[str, Any], - conflict_columns: list[str], - update_columns: list[str] | None = None, - ) -> None: - """ - Insert a row and update specified columns when conflicts occur. - """ - pass - - async def fetch_all( - self, - table: str, - where: Mapping[str, Any] | None = None, - where_sql: str | None = None, - limit: int | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - cursor: tuple[str, str] | None = None, - ) -> PaginatedResponse: - """ - Fetch all rows from a table with optional cursor-based pagination. - - :param table: The table name - :param where: Simple key-value WHERE conditions - :param where_sql: Raw SQL WHERE clause for complex queries - :param limit: Maximum number of records to return - :param order_by: List of (column, order) tuples for sorting - :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page) - Requires order_by with exactly one column when used - :return: PaginatedResult with data and has_more flag - - Note: Cursor pagination only supports single-column ordering for simplicity. - Multi-column ordering is allowed without cursor but will raise an error with cursor. - """ - pass - - async def fetch_one( - self, - table: str, - where: Mapping[str, Any] | None = None, - where_sql: str | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> dict[str, Any] | None: - """ - Fetch one row from a table. - """ - pass - - async def update( - self, - table: str, - data: Mapping[str, Any], - where: Mapping[str, Any], - ) -> None: - """ - Update a row in a table. - """ - pass - - async def delete( - self, - table: str, - where: Mapping[str, Any], - ) -> None: - """ - Delete a row from a table. - """ - pass - - async def add_column_if_not_exists( - self, - table: str, - column_name: str, - column_type: ColumnType, - nullable: bool = True, - ) -> None: - """ - Add a column to an existing table if the column doesn't already exist. - - This is useful for table migrations when adding new functionality. - If the table doesn't exist, this method should do nothing. - If the column already exists, this method should do nothing. - - :param table: Table name - :param column_name: Name of the column to add - :param column_type: Type of the column to add - :param nullable: Whether the column should be nullable (default: True) - """ - pass +from llama_stack.core.storage.sqlstore.api import * # noqa: F401,F403 diff --git a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index ba95dd120..6f69ef680 100644 --- a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -4,333 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from llama_stack.core.access_control.access_control import default_policy, is_action_allowed -from llama_stack.core.access_control.conditions import ProtectedResource -from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope -from llama_stack.core.datatypes import User -from llama_stack.core.request_headers import get_authenticated_user -from llama_stack.core.storage.datatypes import StorageBackendType -from llama_stack.log import get_logger - -from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore - -logger = get_logger(name=__name__, category="providers::utils") - -# Hardcoded copy of the default policy that our SQL filtering implements -# WARNING: If default_policy() changes, this constant must be updated accordingly -# or SQL filtering will fall back to conservative mode (safe but less performant) -# -# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories" -# The corresponding SQL logic is implemented in _build_default_policy_where_clause(): -# - Public records (no access_attributes) are always accessible -# - Records with access_attributes require user to match ALL categories that exist in the resource -# - Missing categories in the resource are treated as "no restriction" (allow) -# - Within each category, user needs ANY matching value (OR logic) -# - Between categories, user needs ALL categories to match (AND logic) -SQL_OPTIMIZED_POLICY = [ - AccessRule( - permit=Scope(actions=list(Action)), - when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"], - ), -] - - -def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]: - """Add access control attributes to a data item.""" - enhanced = dict(item) - if current_user: - enhanced["owner_principal"] = current_user.principal - enhanced["access_attributes"] = current_user.attributes - else: - # IMPORTANT: Use empty string and null value (not None) to match public access filter - # The public access filter in _get_public_access_conditions() expects: - # - owner_principal = '' (empty string) - # - access_attributes = null (JSON null, which serializes to the string 'null') - # Setting them to None (SQL NULL) will cause rows to be filtered out on read. - enhanced["owner_principal"] = "" - enhanced["access_attributes"] = None # Pydantic/JSON will serialize this as JSON null - return enhanced - - -class SqlRecord(ProtectedResource): - def __init__(self, record_id: str, table_name: str, owner: User): - self.type = f"sql_record::{table_name}" - self.identifier = record_id - self.owner = owner - - -class AuthorizedSqlStore: - """ - Authorization layer for SqlStore that provides access control functionality. - - This class composes a base SqlStore and adds authorization methods that handle - access control policies, user attribute capture, and SQL filtering optimization. - """ - - def __init__(self, sql_store: SqlStore, policy: list[AccessRule]): - """ - Initialize the authorization layer. 
- - :param sql_store: Base SqlStore implementation to wrap - :param policy: Access control policy to use for authorization - """ - self.sql_store = sql_store - self.policy = policy - self._detect_database_type() - self._validate_sql_optimized_policy() - - def _detect_database_type(self) -> None: - """Detect the database type from the underlying SQL store.""" - if not hasattr(self.sql_store, "config"): - raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore") - - self.database_type = self.sql_store.config.type.value - if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]: - raise ValueError(f"Unsupported database type: {self.database_type}") - - def _validate_sql_optimized_policy(self) -> None: - """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy(). - - This ensures that if default_policy() changes, we detect the mismatch and - can update our SQL filtering logic accordingly. - """ - actual_default = default_policy() - - if SQL_OPTIMIZED_POLICY != actual_default: - logger.warning( - f"SQL_OPTIMIZED_POLICY does not match default_policy(). " - f"SQL filtering will use conservative mode. " - f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}", - ) - - async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: - """Create a table with built-in access control support.""" - - enhanced_schema = dict(schema) - if "access_attributes" not in enhanced_schema: - enhanced_schema["access_attributes"] = ColumnType.JSON - if "owner_principal" not in enhanced_schema: - enhanced_schema["owner_principal"] = ColumnType.STRING - - await self.sql_store.create_table(table, enhanced_schema) - await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) - await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING) - - async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: - """Insert a row or batch of rows with automatic access control attribute capture.""" - current_user = get_authenticated_user() - enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]] - if isinstance(data, Mapping): - enhanced_data = _enhance_item_with_access_control(data, current_user) - else: - enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data] - await self.sql_store.insert(table, enhanced_data) - - async def upsert( - self, - table: str, - data: Mapping[str, Any], - conflict_columns: list[str], - update_columns: list[str] | None = None, - ) -> None: - """Upsert a row with automatic access control attribute capture.""" - current_user = get_authenticated_user() - enhanced_data = _enhance_item_with_access_control(data, current_user) - await self.sql_store.upsert( - table=table, - data=enhanced_data, - conflict_columns=conflict_columns, - update_columns=update_columns, - ) - - async def fetch_all( - self, - table: str, - where: Mapping[str, Any] | None = None, - limit: int | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - cursor: tuple[str, str] | None = None, - ) -> PaginatedResponse: - """Fetch all rows with automatic access control filtering.""" - access_where = self._build_access_control_where_clause(self.policy) - rows = await self.sql_store.fetch_all( - table=table, - where=where, - where_sql=access_where, - limit=limit, - order_by=order_by, - cursor=cursor, - ) - - current_user = 
get_authenticated_user() - filtered_rows = [] - - for row in rows.data: - stored_access_attrs = row.get("access_attributes") - stored_owner_principal = row.get("owner_principal") or "" - - record_id = row.get("id", "unknown") - sql_record = SqlRecord( - str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) - ) - - if is_action_allowed(self.policy, Action.READ, sql_record, current_user): - filtered_rows.append(row) - - return PaginatedResponse( - data=filtered_rows, - has_more=rows.has_more, - ) - - async def fetch_one( - self, - table: str, - where: Mapping[str, Any] | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> dict[str, Any] | None: - """Fetch one row with automatic access control checking.""" - results = await self.fetch_all( - table=table, - where=where, - limit=1, - order_by=order_by, - ) - - return results.data[0] if results.data else None - - async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None: - """Update rows with automatic access control attribute capture.""" - enhanced_data = dict(data) - - current_user = get_authenticated_user() - if current_user: - enhanced_data["owner_principal"] = current_user.principal - enhanced_data["access_attributes"] = current_user.attributes - else: - # IMPORTANT: Use empty string for owner_principal to match public access filter - enhanced_data["owner_principal"] = "" - enhanced_data["access_attributes"] = None # Will serialize as JSON null - - await self.sql_store.update(table, enhanced_data, where) - - async def delete(self, table: str, where: Mapping[str, Any]) -> None: - """Delete rows with automatic access control filtering.""" - await self.sql_store.delete(table, where) - - def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str: - """Build SQL WHERE clause for access control filtering. - - Only applies SQL filtering for the default policy to ensure correctness. - For custom policies, uses conservative filtering to avoid blocking legitimate access. - """ - current_user = get_authenticated_user() - - if not policy or policy == SQL_OPTIMIZED_POLICY: - return self._build_default_policy_where_clause(current_user) - else: - return self._build_conservative_where_clause() - - def _json_extract(self, column: str, path: str) -> str: - """Extract JSON value (keeping JSON type). - - Args: - column: The JSON column name - path: The JSON path (e.g., 'roles', 'teams') - - Returns: - SQL expression to extract JSON value - """ - if self.database_type == StorageBackendType.SQL_POSTGRES.value: - return f"{column}->'{path}'" - elif self.database_type == StorageBackendType.SQL_SQLITE.value: - return f"JSON_EXTRACT({column}, '$.{path}')" - else: - raise ValueError(f"Unsupported database type: {self.database_type}") - - def _json_extract_text(self, column: str, path: str) -> str: - """Extract JSON value as text. - - Args: - column: The JSON column name - path: The JSON path (e.g., 'roles', 'teams') - - Returns: - SQL expression to extract JSON value as text - """ - if self.database_type == StorageBackendType.SQL_POSTGRES.value: - return f"{column}->>'{path}'" - elif self.database_type == StorageBackendType.SQL_SQLITE.value: - return f"JSON_EXTRACT({column}, '$.{path}')" - else: - raise ValueError(f"Unsupported database type: {self.database_type}") - - def _get_public_access_conditions(self) -> list[str]: - """Get the SQL conditions for public access. 
- - Public records are those with: - - owner_principal = '' (empty string) - - access_attributes is either SQL NULL or JSON null - - Note: Different databases serialize None differently: - - SQLite: None → JSON null (text = 'null') - - Postgres: None → SQL NULL (IS NULL) - """ - conditions = ["owner_principal = ''"] - if self.database_type == StorageBackendType.SQL_POSTGRES.value: - # Accept both SQL NULL and JSON null for Postgres compatibility - # This handles both old rows (SQL NULL) and new rows (JSON null) - conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')") - elif self.database_type == StorageBackendType.SQL_SQLITE.value: - # SQLite serializes None as JSON null - conditions.append("(access_attributes IS NULL OR access_attributes = 'null')") - else: - raise ValueError(f"Unsupported database type: {self.database_type}") - return conditions - - def _build_default_policy_where_clause(self, current_user: User | None) -> str: - """Build SQL WHERE clause for the default policy. - - Default policy: permit all actions when user in owners [roles, teams, projects, namespaces] - This means user must match ALL attribute categories that exist in the resource. - """ - base_conditions = self._get_public_access_conditions() - user_attr_conditions = [] - - if current_user and current_user.attributes: - for attr_key, user_values in current_user.attributes.items(): - if user_values: - value_conditions = [] - for value in user_values: - # Check if JSON array contains the value - escaped_value = value.replace("'", "''") - json_text = self._json_extract_text("access_attributes", attr_key) - value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')") - - if value_conditions: - # Check if the category is missing (NULL) - category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL" - user_matches_category = f"({' OR '.join(value_conditions)})" - user_attr_conditions.append(f"({category_missing} OR {user_matches_category})") - - if user_attr_conditions: - all_requirements_met = f"({' AND '.join(user_attr_conditions)})" - base_conditions.append(all_requirements_met) - - return f"({' OR '.join(base_conditions)})" - - def _build_conservative_where_clause(self) -> str: - """Conservative SQL filtering for custom policies. - - Only filters records we're 100% certain would be denied by any reasonable policy. - """ - current_user = get_authenticated_user() - - if not current_user: - # Only allow public records - base_conditions = self._get_public_access_conditions() - return f"({' OR '.join(base_conditions)})" - - return "1=1" +from llama_stack.core.storage.sqlstore.authorized_sqlstore import * # noqa: F401,F403 diff --git a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 10009d396..ee6f2ad3e 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -3,372 +3,5 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
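The deleted `_build_default_policy_where_clause` above is the heart of the SQL-side access filtering: public rows are OR-joined base conditions, and an authenticated user must match every attribute category a row declares. A minimal standalone sketch of that logic, reduced to the SQLite dialect (function names here are illustrative, not the moved module's API):

```python
def _json_extract_text(column: str, path: str) -> str:
    # SQLite form; the Postgres branch in the deleted code uses column->>'path'.
    return f"JSON_EXTRACT({column}, '$.{path}')"


def build_default_policy_where_clause(user_attributes: dict[str, list[str]] | None) -> str:
    # Base (public) conditions, OR-joined as in the deleted code:
    # empty owner principal, or NULL / JSON-null access attributes.
    conditions = [
        "owner_principal = ''",
        "(access_attributes IS NULL OR access_attributes = 'null')",
    ]

    attr_conditions = []
    for attr_key, user_values in (user_attributes or {}).items():
        value_checks = []
        for value in user_values:
            escaped = value.replace("'", "''")  # escape for SQL string literal
            value_checks.append(
                f"({_json_extract_text('access_attributes', attr_key)} LIKE '%\"{escaped}\"%')"
            )
        if value_checks:
            # A category the row does not declare is treated as satisfied.
            missing = f"{_json_extract_text('access_attributes', attr_key)} IS NULL"
            attr_conditions.append(f"({missing} OR ({' OR '.join(value_checks)}))")

    if attr_conditions:
        # User must match ALL categories present on the row.
        conditions.append(f"({' AND '.join(attr_conditions)})")
    return f"({' OR '.join(conditions)})"


print(build_default_policy_where_clause({"roles": ["admin"], "teams": ["ml"]}))
```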
-from collections.abc import Mapping, Sequence -from typing import Any, Literal, cast -from sqlalchemy import ( - JSON, - Boolean, - Column, - DateTime, - Float, - Integer, - MetaData, - String, - Table, - Text, - event, - inspect, - select, - text, -) -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from sqlalchemy.ext.asyncio.engine import AsyncEngine -from sqlalchemy.sql.elements import ColumnElement - -from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig -from llama_stack.log import get_logger -from llama_stack_api import PaginatedResponse - -from .api import ColumnDefinition, ColumnType, SqlStore - -logger = get_logger(name=__name__, category="providers::utils") - -TYPE_MAPPING: dict[ColumnType, Any] = { - ColumnType.INTEGER: Integer, - ColumnType.STRING: String, - ColumnType.FLOAT: Float, - ColumnType.BOOLEAN: Boolean, - ColumnType.DATETIME: DateTime, - ColumnType.TEXT: Text, - ColumnType.JSON: JSON, -} - - -def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement: - """Return a SQLAlchemy expression for a where condition. - - `value` may be a simple scalar (equality) or a mapping like {">": 123}. - The returned expression is a SQLAlchemy ColumnElement usable in query.where(...). - """ - if isinstance(value, Mapping): - if len(value) != 1: - raise ValueError(f"Operator mapping must have a single operator, got: {value}") - op, operand = next(iter(value.items())) - if op == "==" or op == "=": - return cast(ColumnElement[Any], column == operand) - if op == ">": - return cast(ColumnElement[Any], column > operand) - if op == "<": - return cast(ColumnElement[Any], column < operand) - if op == ">=": - return cast(ColumnElement[Any], column >= operand) - if op == "<=": - return cast(ColumnElement[Any], column <= operand) - raise ValueError(f"Unsupported operator '{op}' in where mapping") - return cast(ColumnElement[Any], column == value) - - -class SqlAlchemySqlStoreImpl(SqlStore): - def __init__(self, config: SqlAlchemySqlStoreConfig): - self.config = config - self._is_sqlite_backend = "sqlite" in self.config.engine_str - self.async_session = async_sessionmaker(self.create_engine()) - self.metadata = MetaData() - - def create_engine(self) -> AsyncEngine: - # Configure connection args for better concurrency support - connect_args = {} - if self._is_sqlite_backend: - # SQLite-specific optimizations for concurrent access - # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases - connect_args["timeout"] = 5.0 - connect_args["check_same_thread"] = False # Allow usage across asyncio tasks - - engine = create_async_engine( - self.config.engine_str, - pool_pre_ping=True, - connect_args=connect_args, - ) - - # Enable WAL mode for SQLite to support concurrent readers and writers - if self._is_sqlite_backend: - - @event.listens_for(engine.sync_engine, "connect") - def set_sqlite_pragma(dbapi_conn, connection_record): - cursor = dbapi_conn.cursor() - # Enable Write-Ahead Logging for better concurrency - cursor.execute("PRAGMA journal_mode=WAL") - # Set busy timeout to 5 seconds (retry instead of immediate failure) - # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue - cursor.execute("PRAGMA busy_timeout=5000") - # Use NORMAL synchronous mode for better performance (still safe with WAL) - cursor.execute("PRAGMA synchronous=NORMAL") - cursor.close() - - return engine - - async def create_table( - self, - table: str, - schema: Mapping[str, ColumnType | ColumnDefinition], - ) -> None: - 
if not schema: - raise ValueError(f"No columns defined for table '{table}'.") - - sqlalchemy_columns: list[Column] = [] - - for col_name, col_props in schema.items(): - col_type = None - is_primary_key = False - is_nullable = True - - if isinstance(col_props, ColumnType): - col_type = col_props - elif isinstance(col_props, ColumnDefinition): - col_type = col_props.type - is_primary_key = col_props.primary_key - is_nullable = col_props.nullable - - sqlalchemy_type = TYPE_MAPPING.get(col_type) - if not sqlalchemy_type: - raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.") - - sqlalchemy_columns.append( - Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable) - ) - - if table not in self.metadata.tables: - sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns) - else: - sqlalchemy_table = self.metadata.tables[table] - - engine = self.create_engine() - async with engine.begin() as conn: - await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) - - async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: - async with self.async_session() as session: - await session.execute(self.metadata.tables[table].insert(), data) - await session.commit() - - async def upsert( - self, - table: str, - data: Mapping[str, Any], - conflict_columns: list[str], - update_columns: list[str] | None = None, - ) -> None: - table_obj = self.metadata.tables[table] - dialect_insert = self._get_dialect_insert(table_obj) - insert_stmt = dialect_insert.values(**data) - - if update_columns is None: - update_columns = [col for col in data.keys() if col not in conflict_columns] - - update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns} - conflict_cols = [table_obj.c[col] for col in conflict_columns] - - stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping) - - async with self.async_session() as session: - await session.execute(stmt) - await session.commit() - - async def fetch_all( - self, - table: str, - where: Mapping[str, Any] | None = None, - where_sql: str | None = None, - limit: int | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - cursor: tuple[str, str] | None = None, - ) -> PaginatedResponse: - async with self.async_session() as session: - table_obj = self.metadata.tables[table] - query = select(table_obj) - - if where: - for key, value in where.items(): - query = query.where(_build_where_expr(table_obj.c[key], value)) - - if where_sql: - query = query.where(text(where_sql)) - - # Handle cursor-based pagination - if cursor: - # Validate cursor tuple format - if not isinstance(cursor, tuple) or len(cursor) != 2: - raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}") - - # Require order_by for cursor pagination - if not order_by: - raise ValueError("order_by is required when using cursor pagination") - - # Only support single-column ordering for cursor pagination - if len(order_by) != 1: - raise ValueError( - f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns" - ) - - cursor_key_column, cursor_id = cursor - order_column, order_direction = order_by[0] - - # Verify cursor_key_column exists - if cursor_key_column not in table_obj.c: - raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'") - - # Get cursor value for the order column - cursor_query = 
select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id) - cursor_result = await session.execute(cursor_query) - cursor_row = cursor_result.fetchone() - - if not cursor_row: - raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'") - - cursor_value = cursor_row[0] - - # Apply cursor condition based on sort direction - if order_direction == "desc": - query = query.where(table_obj.c[order_column] < cursor_value) - else: - query = query.where(table_obj.c[order_column] > cursor_value) - - # Apply ordering - if order_by: - if not isinstance(order_by, list): - raise ValueError( - f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" - ) - for order in order_by: - if not isinstance(order, tuple): - raise ValueError( - f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" - ) - name, order_type = order - if name not in table_obj.c: - raise ValueError(f"Column '{name}' not found in table '{table}'") - if order_type == "asc": - query = query.order_by(table_obj.c[name].asc()) - elif order_type == "desc": - query = query.order_by(table_obj.c[name].desc()) - else: - raise ValueError(f"Invalid order '{order_type}' for column '{name}'") - - # Fetch limit + 1 to determine has_more - fetch_limit = limit - if limit: - fetch_limit = limit + 1 - - if fetch_limit: - query = query.limit(fetch_limit) - - result = await session.execute(query) - # Iterate directly - if no rows, list comprehension yields empty list - rows = [dict(row._mapping) for row in result] - - # Always return pagination result - has_more = False - if limit and len(rows) > limit: - has_more = True - rows = rows[:limit] - - return PaginatedResponse(data=rows, has_more=has_more) - - async def fetch_one( - self, - table: str, - where: Mapping[str, Any] | None = None, - where_sql: str | None = None, - order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> dict[str, Any] | None: - result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by) - if not result.data: - return None - return result.data[0] - - async def update( - self, - table: str, - data: Mapping[str, Any], - where: Mapping[str, Any], - ) -> None: - if not where: - raise ValueError("where is required for update") - - async with self.async_session() as session: - stmt = self.metadata.tables[table].update() - for key, value in where.items(): - stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value)) - await session.execute(stmt, data) - await session.commit() - - async def delete(self, table: str, where: Mapping[str, Any]) -> None: - if not where: - raise ValueError("where is required for delete") - - async with self.async_session() as session: - stmt = self.metadata.tables[table].delete() - for key, value in where.items(): - stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value)) - await session.execute(stmt) - await session.commit() - - async def add_column_if_not_exists( - self, - table: str, - column_name: str, - column_type: ColumnType, - nullable: bool = True, - ) -> None: - """Add a column to an existing table if the column doesn't already exist.""" - engine = self.create_engine() - - try: - async with engine.begin() as conn: - - def check_column_exists(sync_conn): - inspector = inspect(sync_conn) - - table_names = inspector.get_table_names() - if table not in table_names: - return False, False # table doesn't exist, column doesn't exist - - existing_columns = 
inspector.get_columns(table) - column_names = [col["name"] for col in existing_columns] - - return True, column_name in column_names # table exists, column exists or not - - table_exists, column_exists = await conn.run_sync(check_column_exists) - if not table_exists or column_exists: - return - - sqlalchemy_type = TYPE_MAPPING.get(column_type) - if not sqlalchemy_type: - raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.") - - # Create the ALTER TABLE statement - # Note: We need to get the dialect-specific type name - dialect = engine.dialect - type_impl = sqlalchemy_type() - compiled_type = type_impl.compile(dialect=dialect) - - nullable_clause = "" if nullable else " NOT NULL" - add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") - - await conn.execute(add_column_sql) - except Exception as e: - # If any error occurs during migration, log it but don't fail - # The table creation will handle adding the column - logger.error(f"Error adding column {column_name} to table {table}: {e}") - pass - - def _get_dialect_insert(self, table: Table): - if self._is_sqlite_backend: - from sqlalchemy.dialects.sqlite import insert as sqlite_insert - - return sqlite_insert(table) - else: - from sqlalchemy.dialects.postgresql import insert as pg_insert - - return pg_insert(table) +from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import * # noqa: F401,F403 diff --git a/src/llama_stack/providers/utils/sqlstore/sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlstore.py index 9409b7d00..1628523a7 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -4,85 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from threading import Lock -from typing import Annotated, cast - -from pydantic import Field - -from llama_stack.core.storage.datatypes import ( - PostgresSqlStoreConfig, - SqliteSqlStoreConfig, - SqlStoreReference, - StorageBackendConfig, - StorageBackendType, -) - -from .api import SqlStore - -sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"] - -_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {} -_SQLSTORE_INSTANCES: dict[str, SqlStore] = {} -_SQLSTORE_LOCKS: dict[str, Lock] = {} - - -SqlStoreConfig = Annotated[ - SqliteSqlStoreConfig | PostgresSqlStoreConfig, - Field(discriminator="type"), -] - - -def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]: - """Get pip packages for SQL store config, handling both dict and object cases.""" - if isinstance(store_config, dict): - store_type = store_config.get("type") - if store_type == StorageBackendType.SQL_SQLITE.value: - return SqliteSqlStoreConfig.pip_packages() - elif store_type == StorageBackendType.SQL_POSTGRES.value: - return PostgresSqlStoreConfig.pip_packages() - else: - raise ValueError(f"Unknown SQL store type: {store_type}") - else: - return store_config.pip_packages() - - -def sqlstore_impl(reference: SqlStoreReference) -> SqlStore: - backend_name = reference.backend - - backend_config = _SQLSTORE_BACKENDS.get(backend_name) - if backend_config is None: - raise ValueError( - f"Unknown SQL store backend '{backend_name}'. 
Registered backends: {sorted(_SQLSTORE_BACKENDS)}" - ) - - existing = _SQLSTORE_INSTANCES.get(backend_name) - if existing: - return existing - - lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock()) - with lock: - existing = _SQLSTORE_INSTANCES.get(backend_name) - if existing: - return existing - - if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig): - from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl - - config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy() - instance = SqlAlchemySqlStoreImpl(config) - _SQLSTORE_INSTANCES[backend_name] = instance - return instance - else: - raise ValueError(f"Unknown sqlstore type {backend_config.type}") - - -def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None: - """Register the set of available SQL store backends for reference resolution.""" - global _SQLSTORE_BACKENDS - global _SQLSTORE_INSTANCES - - _SQLSTORE_BACKENDS.clear() - _SQLSTORE_INSTANCES.clear() - _SQLSTORE_LOCKS.clear() - for name, cfg in backends.items(): - _SQLSTORE_BACKENDS[name] = cfg +from llama_stack.core.storage.sqlstore.sqlstore import * # noqa: F401,F403 diff --git a/src/llama_stack_api/internal/__init__.py b/src/llama_stack_api/internal/__init__.py new file mode 100644 index 000000000..5be0ba9a9 --- /dev/null +++ b/src/llama_stack_api/internal/__init__.py @@ -0,0 +1,3 @@ +# Internal subpackage for shared interfaces that are not part of the public API. + +__all__: list[str] = [] diff --git a/src/llama_stack_api/internal/kvstore.py b/src/llama_stack_api/internal/kvstore.py new file mode 100644 index 000000000..a6d982261 --- /dev/null +++ b/src/llama_stack_api/internal/kvstore.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime +from typing import Protocol + + +class KVStore(Protocol): + """Protocol for simple key/value storage backends.""" + + # TODO: make the value type bytes instead of str + async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ... + + async def get(self, key: str) -> str | None: ... + + async def delete(self, key: str) -> None: ... + + async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ... + + async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ... + + +__all__ = ["KVStore"] diff --git a/src/llama_stack_api/internal/sqlstore.py b/src/llama_stack_api/internal/sqlstore.py new file mode 100644 index 000000000..ebb2d8ba2 --- /dev/null +++ b/src/llama_stack_api/internal/sqlstore.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
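The `sqlstore_impl` / `register_sqlstore_backends` pair removed above follows a registry-plus-double-checked-locking pattern, so each named backend resolves to exactly one store instance even under concurrent callers. A self-contained sketch of that pattern, with illustrative names and a plain dict standing in for real store construction:

```python
from threading import Lock

_BACKENDS: dict[str, dict] = {}
_INSTANCES: dict[str, object] = {}
_LOCKS: dict[str, Lock] = {}


def register_backends(backends: dict[str, dict]) -> None:
    # Re-registration resets caches, mirroring register_sqlstore_backends.
    _BACKENDS.clear()
    _INSTANCES.clear()
    _LOCKS.clear()
    _BACKENDS.update(backends)


def resolve(backend_name: str) -> object:
    config = _BACKENDS.get(backend_name)
    if config is None:
        raise ValueError(f"Unknown backend '{backend_name}'. Registered: {sorted(_BACKENDS)}")
    if (existing := _INSTANCES.get(backend_name)) is not None:
        return existing  # fast path, no lock
    lock = _LOCKS.setdefault(backend_name, Lock())
    with lock:
        # Re-check under the lock: another thread may have won the race.
        if (existing := _INSTANCES.get(backend_name)) is not None:
            return existing
        instance = dict(config)  # stand-in for constructing the real store
        _INSTANCES[backend_name] = instance
        return instance


register_backends({"default": {"type": "sqlite", "db_path": ":memory:"}})
assert resolve("default") is resolve("default")
```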
+ +from collections.abc import Mapping, Sequence +from enum import Enum +from typing import Any, Literal, Protocol + +from pydantic import BaseModel + +from llama_stack_api import PaginatedResponse + + +class ColumnType(Enum): + INTEGER = "INTEGER" + STRING = "STRING" + TEXT = "TEXT" + FLOAT = "FLOAT" + BOOLEAN = "BOOLEAN" + JSON = "JSON" + DATETIME = "DATETIME" + + +class ColumnDefinition(BaseModel): + type: ColumnType + primary_key: bool = False + nullable: bool = True + default: Any = None + + +class SqlStore(Protocol): + """Protocol for common SQL-store functionality.""" + + async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: ... + + async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: ... + + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: ... + + async def fetch_all( + self, + table: str, + where: Mapping[str, Any] | None = None, + where_sql: str | None = None, + limit: int | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + cursor: tuple[str, str] | None = None, + ) -> PaginatedResponse: ... + + async def fetch_one( + self, + table: str, + where: Mapping[str, Any] | None = None, + where_sql: str | None = None, + order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, + ) -> dict[str, Any] | None: ... + + async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None: ... + + async def delete(self, table: str, where: Mapping[str, Any]) -> None: ... + + async def add_column_if_not_exists( + self, + table: str, + column_name: str, + column_type: ColumnType, + nullable: bool = True, + ) -> None: ... 
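The new `internal` modules expose `KVStore` and `SqlStore` as `typing.Protocol`s, so implementations type-check structurally and need no inheritance from, or runtime dependency on, `llama_stack_api` classes. A toy demonstration of that property on a reduced two-method protocol (the `InMemoryKVStore` class is hypothetical, not part of this diff):

```python
import asyncio
from datetime import datetime
from typing import Protocol


class KVStore(Protocol):
    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ...
    async def get(self, key: str) -> str | None: ...


class InMemoryKVStore:
    """Satisfies KVStore structurally -- note: no subclassing."""

    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        self._data[key] = value

    async def get(self, key: str) -> str | None:
        return self._data.get(key)


async def demo(store: KVStore) -> str | None:
    await store.set("k", "v")
    return await store.get("k")


assert asyncio.run(demo(InMemoryKVStore())) == "v"
```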
+ + +__all__ = ["ColumnDefinition", "ColumnType", "SqlStore"] diff --git a/tests/integration/files/test_files.py b/tests/integration/files/test_files.py index 1f19c88c5..e8004c95d 100644 --- a/tests/integration/files/test_files.py +++ b/tests/integration/files/test_files.py @@ -175,7 +175,7 @@ def test_expires_after_requests(openai_client): @pytest.mark.xfail(message="User isolation broken for current providers, must be fixed.") -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack_client): """Test that users can only access their own files.""" from llama_stack_client import NotFoundError @@ -275,7 +275,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack raise e -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") def test_files_authentication_shared_attributes( mock_get_authenticated_user, llama_stack_client, provider_type_is_openai ): @@ -335,7 +335,7 @@ def test_files_authentication_shared_attributes( raise e -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") def test_files_authentication_anonymous_access( mock_get_authenticated_user, llama_stack_client, provider_type_is_openai ): diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index ad9115756..4f4f4a8dd 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -13,14 +13,14 @@ import pytest from llama_stack.core.access_control.access_control import default_policy from llama_stack.core.datatypes import User from llama_stack.core.storage.datatypes import SqlStoreReference -from llama_stack.providers.utils.sqlstore.api import ColumnType -from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlstore import ( +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.core.storage.sqlstore.sqlstore import ( PostgresSqlStoreConfig, SqliteSqlStoreConfig, register_sqlstore_backends, sqlstore_impl, ) +from llama_stack_api.internal.sqlstore import ColumnType def get_postgres_config(): @@ -96,7 +96,7 @@ async def cleanup_records(sql_store, table_name, record_ids): @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS) -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request): """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite""" backend_name = request.node.callspec.id @@ -190,7 +190,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS) -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") 
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request): """Test that 'user is owner' policies work correctly with record ownership""" from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope diff --git a/tests/unit/conversations/test_conversations.py b/tests/unit/conversations/test_conversations.py index 95c54d379..e8286576b 100644 --- a/tests/unit/conversations/test_conversations.py +++ b/tests/unit/conversations/test_conversations.py @@ -23,7 +23,7 @@ from llama_stack.core.storage.datatypes import ( SqlStoreReference, StorageConfig, ) -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import OpenAIResponseInputMessageContentText, OpenAIResponseMessage diff --git a/tests/unit/files/test_files.py b/tests/unit/files/test_files.py index 793f4edd3..546340c1f 100644 --- a/tests/unit/files/test_files.py +++ b/tests/unit/files/test_files.py @@ -13,7 +13,7 @@ from llama_stack.providers.inline.files.localfs import ( LocalfsFilesImpl, LocalfsFilesImplConfig, ) -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import OpenAIFilePurpose, Order, ResourceNotFoundError diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py index 443a1d371..18bfdf93d 100644 --- a/tests/unit/fixtures.py +++ b/tests/unit/fixtures.py @@ -7,8 +7,8 @@ import pytest from llama_stack.core.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl +from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig +from llama_stack.core.storage.kvstore.sqlite import SqliteKVStoreImpl @pytest.fixture(scope="function") diff --git a/tests/unit/prompts/prompts/conftest.py b/tests/unit/prompts/prompts/conftest.py index c876f2041..8bfc1f03c 100644 --- a/tests/unit/prompts/prompts/conftest.py +++ b/tests/unit/prompts/prompts/conftest.py @@ -18,7 +18,7 @@ from llama_stack.core.storage.datatypes import ( SqlStoreReference, StorageConfig, ) -from llama_stack.providers.utils.kvstore import register_kvstore_backends +from llama_stack.core.storage.kvstore import register_kvstore_backends @pytest.fixture diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 78f0d7cfd..f4b489339 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -24,7 +24,7 @@ from llama_stack.providers.utils.responses.responses_store import ( ResponsesStore, _OpenAIResponseObjectWithInputAndMessages, ) -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api.agents import Order from llama_stack_api.inference import ( OpenAIAssistantMessageParam, diff --git a/tests/unit/providers/batches/conftest.py b/tests/unit/providers/batches/conftest.py index d161bf976..f96656453 100644 --- a/tests/unit/providers/batches/conftest.py +++ 
b/tests/unit/providers/batches/conftest.py @@ -15,7 +15,7 @@ import pytest from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig -from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends +from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends @pytest.fixture diff --git a/tests/unit/providers/files/conftest.py b/tests/unit/providers/files/conftest.py index c64ecc3a3..5f16d7d49 100644 --- a/tests/unit/providers/files/conftest.py +++ b/tests/unit/providers/files/conftest.py @@ -10,7 +10,7 @@ from moto import mock_aws from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends class MockUploadFile: diff --git a/tests/unit/providers/files/test_s3_files_auth.py b/tests/unit/providers/files/test_s3_files_auth.py index e113611bd..49b33fd7b 100644 --- a/tests/unit/providers/files/test_s3_files_auth.py +++ b/tests/unit/providers/files/test_s3_files_auth.py @@ -18,11 +18,11 @@ async def test_listing_hides_other_users_file(s3_provider, sample_text_file): user_a = User("user-a", {"roles": ["team-a"]}) user_b = User("user-b", {"roles": ["team-b"]}) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_a uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_b listed = await s3_provider.openai_list_files() assert all(f.id != uploaded.id for f in listed.data) @@ -41,11 +41,11 @@ async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op): user_a = User("user-a", {"roles": ["team-a"]}) user_b = User("user-b", {"roles": ["team-b"]}) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_a uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_b with pytest.raises(ResourceNotFoundError): await op(s3_provider, uploaded.id) @@ -56,11 +56,11 @@ async def test_shared_role_allows_listing(s3_provider, sample_text_file): user_a = User("user-a", {"roles": ["shared-role"]}) user_b = User("user-b", {"roles": ["shared-role"]}) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + 
with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_a uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_b listed = await s3_provider.openai_list_files() assert any(f.id == uploaded.id for f in listed.data) @@ -79,10 +79,10 @@ async def test_shared_role_allows_access(s3_provider, sample_text_file, op): user_x = User("user-x", {"roles": ["shared-role"]}) user_y = User("user-y", {"roles": ["shared-role"]}) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_x uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) - with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: + with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user: mock_get_user.return_value = user_y await op(s3_provider, uploaded.id) diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 6408e25ab..0fd0267ff 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -17,7 +17,7 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConf from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter -from llama_stack.providers.utils.kvstore import register_kvstore_backends +from llama_stack.core.storage.kvstore import register_kvstore_backends from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore EMBEDDING_DIMENSION = 768 @@ -279,7 +279,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd ) as mock_check_version: mock_check_version.return_value = "0.5.1" - with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl: + with patch("llama_stack.core.storage.kvstore.kvstore_impl") as mock_kvstore_impl: mock_kvstore = AsyncMock() mock_kvstore_impl.return_value = mock_kvstore diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index 1b5032782..cc3de1065 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -14,7 +14,7 @@ from llama_stack.core.store.registry import ( CachedDiskDistributionRegistry, DiskDistributionRegistry, ) -from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends +from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends from llama_stack_api import Model, VectorStore diff --git a/tests/unit/server/test_quota.py b/tests/unit/server/test_quota.py index 0939414dd..cd8c38eed 100644 --- a/tests/unit/server/test_quota.py +++ b/tests/unit/server/test_quota.py @@ -15,7 +15,7 
@@ from starlette.middleware.base import BaseHTTPMiddleware from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod from llama_stack.core.server.quota import QuotaMiddleware from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore import register_kvstore_backends +from llama_stack.core.storage.kvstore import register_kvstore_backends @pytest.fixture diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index 8f8a61ea7..a1b03f630 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -24,8 +24,8 @@ from llama_stack.core.storage.datatypes import ( SqlStoreReference, StorageConfig, ) -from llama_stack.providers.utils.kvstore import register_kvstore_backends -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.kvstore import register_kvstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import Inference, InlineProviderSpec, ProviderSpec diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py index bdcc529ce..70da0a95e 100644 --- a/tests/unit/utils/inference/test_inference_store.py +++ b/tests/unit/utils/inference/test_inference_store.py @@ -10,7 +10,7 @@ import pytest from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import ( OpenAIAssistantMessageParam, OpenAIChatCompletion, diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py index a31377306..1aaf57b44 100644 --- a/tests/unit/utils/kvstore/test_sqlite_memory.py +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -5,8 +5,8 @@ # the root directory of this source tree. 
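The reason every `@patch(...)` target string in these tests must change along with the module move: `unittest.mock.patch` replaces a name where it is looked up, not where it is defined, so `get_authenticated_user` must be patched in `authorized_sqlstore`'s new home. A self-contained demonstration using synthetic modules (`helper` and `consumer` are stand-ins, not real packages):

```python
import sys
import types
from unittest.mock import patch

# Simulate "consumer.py" doing: from helper import get_user
helper = types.ModuleType("helper")
helper.get_user = lambda: "real"
consumer = types.ModuleType("consumer")
consumer.get_user = helper.get_user
consumer.who = lambda: consumer.get_user()  # looks the name up at call time
sys.modules["helper"] = helper
sys.modules["consumer"] = consumer

with patch("consumer.get_user", return_value="mocked"):
    assert consumer.who() == "mocked"  # lookup site patched: effective

with patch("helper.get_user", return_value="mocked"):
    assert consumer.who() == "real"    # definition site patched: no effect
```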
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl +from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig +from llama_stack.core.storage.kvstore.sqlite.sqlite import SqliteKVStoreImpl async def test_memory_kvstore_persistence_behavior(): diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py index 8c108d9c1..284a3c350 100644 --- a/tests/unit/utils/responses/test_responses_store.py +++ b/tests/unit/utils/responses/test_responses_store.py @@ -12,7 +12,7 @@ import pytest from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig from llama_stack.providers.utils.responses.responses_store import ResponsesStore -from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends from llama_stack_api import OpenAIMessageParam, OpenAIResponseInput, OpenAIResponseObject, OpenAIUserMessageParam, Order diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index d7ba0dc89..421e3b69d 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -9,9 +9,9 @@ from tempfile import TemporaryDirectory import pytest -from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType -from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl +from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType async def test_sqlite_sqlstore(): diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index d85e784a9..e9a6b511b 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -10,13 +10,13 @@ from unittest.mock import patch from llama_stack.core.access_control.access_control import default_policy, is_action_allowed from llama_stack.core.access_control.datatypes import Action from llama_stack.core.datatypes import User -from llama_stack.providers.utils.sqlstore.api import ColumnType -from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord -from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord +from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl +from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig +from llama_stack_api.internal.sqlstore import ColumnType -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user): """Test that fetch_all works correctly with where_sql for access control""" with TemporaryDirectory() as tmp_dir: @@ -78,7 +78,7 @@ async def 
test_authorized_fetch_with_where_sql_access_control(mock_get_authentic assert row["title"] == "User Document" -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_sql_policy_consistency(mock_get_authenticated_user): """Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic""" with TemporaryDirectory() as tmp_dir: @@ -164,7 +164,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): ) -@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user): """Test that user attributes are properly captured during insert""" with TemporaryDirectory() as tmp_dir:
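For context on what these attribute-capture tests exercise: on insert and upsert, the authorized store stamps each row with the caller's principal and attributes, falling back to the public markers (empty principal, null attributes) for anonymous writers, so the WHERE-clause filtering shown earlier can recover ownership later. A simplified sketch of that capture step, with names abbreviated from the deleted code above:

```python
from dataclasses import dataclass


@dataclass
class User:
    principal: str
    attributes: dict[str, list[str]] | None


def enhance_with_access_control(item: dict, user: User | None) -> dict:
    enhanced = dict(item)
    if user:
        enhanced["owner_principal"] = user.principal
        enhanced["access_attributes"] = user.attributes
    else:
        enhanced["owner_principal"] = ""      # matches the public-access filter
        enhanced["access_attributes"] = None  # serialized as JSON null
    return enhanced


row = enhance_with_access_control({"id": "doc-1"}, User("alice", {"roles": ["admin"]}))
assert row["owner_principal"] == "alice"
assert enhance_with_access_control({"id": "doc-2"}, None)["owner_principal"] == ""
```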