mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

refactor(storage): centralize kv/sql stores in core so providers share one API

This commit is contained in:
parent 7093978754 · commit dcb0c6b127

86 changed files with 1710 additions and 1616 deletions
@@ -12,9 +12,8 @@ from pydantic import BaseModel, TypeAdapter
 from llama_stack.core.datatypes import AccessRule, StackRunConfig
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     Conversation,
     ConversationDeletedResource,
@@ -25,6 +24,7 @@ from llama_stack_api import (
     Conversations,
     Metadata,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

 logger = get_logger(name=__name__, category="openai_conversations")
@@ -10,7 +10,7 @@ from typing import Any
 from pydantic import BaseModel

 from llama_stack.core.datatypes import StackRunConfig
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
 from llama_stack_api import ListPromptsResponse, Prompt, Prompts
@@ -12,8 +12,8 @@ from starlette.types import ASGIApp, Receive, Scope, Send
 from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore.api import KVStore
-from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
+from llama_stack.core.storage.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
+from llama_stack_api.internal.kvstore import KVStore

 logger = get_logger(name=__name__, category="core::server")
@@ -385,8 +385,8 @@ def _initialize_storage(run_config: StackRunConfig):
     else:
         raise ValueError(f"Unknown storage backend type: {type}")

-    from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
-    from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+    from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends
+    from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends

     register_kvstore_backends(kv_backends)
     register_sqlstore_backends(sql_backends)
src/llama_stack/core/storage/kvstore/__init__.py (new file, +7)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .kvstore import *  # noqa: F401, F403
src/llama_stack/core/storage/kvstore/api.py (new file, +9)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack_api.internal.kvstore import KVStore

__all__ = ["KVStore"]
src/llama_stack/core/storage/kvstore/config.py (new file, +39)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated

from pydantic import Field

from llama_stack.core.storage.datatypes import (
    MongoDBKVStoreConfig,
    PostgresKVStoreConfig,
    RedisKVStoreConfig,
    SqliteKVStoreConfig,
    StorageBackendType,
)

KVStoreConfig = Annotated[
    RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
]


def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
    """Get pip packages for KV store config, handling both dict and object cases."""
    if isinstance(store_config, dict):
        store_type = store_config.get("type")
        if store_type == StorageBackendType.KV_SQLITE.value:
            return SqliteKVStoreConfig.pip_packages()
        elif store_type == StorageBackendType.KV_POSTGRES.value:
            return PostgresKVStoreConfig.pip_packages()
        elif store_type == StorageBackendType.KV_REDIS.value:
            return RedisKVStoreConfig.pip_packages()
        elif store_type == StorageBackendType.KV_MONGODB.value:
            return MongoDBKVStoreConfig.pip_packages()
        else:
            raise ValueError(f"Unknown KV store type: {store_type}")
    else:
        return store_config.pip_packages()
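A short usage sketch of the helper above. The dict shape mirrors what a raw run-config would deserialize to before Pydantic validation; the exact `"type"` string and the resulting package list are assumptions here, since the enum values are defined elsewhere in `storage.datatypes`:

    # Hypothetical call with a raw dict, e.g. parsed from run.yaml before validation.
    from llama_stack.core.storage.kvstore.config import get_pip_packages

    deps = get_pip_packages({"type": "kv_sqlite", "db_path": "/tmp/kv.db"})  # assumed enum value
    print(deps)  # whatever SqliteKVStoreConfig.pip_packages() returns, e.g. ["aiosqlite"]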
src/llama_stack/core/storage/kvstore/kvstore.py (new file, +118)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from __future__ import annotations

import asyncio
from collections import defaultdict

from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType

from .api import KVStore
from .config import KVStoreConfig


def kvstore_dependencies():
    """
    Returns all possible kvstore dependencies for registry/provider specifications.

    NOTE: For specific kvstore implementations, use config.pip_packages instead.
    This function returns the union of all dependencies for cases where the specific
    kvstore type is not known at declaration time (e.g., provider registries).
    """
    return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]


class InmemoryKVStoreImpl(KVStore):
    def __init__(self):
        self._store = {}

    async def initialize(self) -> None:
        pass

    async def get(self, key: str) -> str | None:
        return self._store.get(key)

    async def set(self, key: str, value: str) -> None:
        self._store[key] = value

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        """Get all keys in the given range."""
        return [key for key in self._store.keys() if key >= start_key and key < end_key]

    async def delete(self, key: str) -> None:
        del self._store[key]


_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)


def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Register the set of available KV store backends for reference resolution."""
    global _KVSTORE_BACKENDS
    global _KVSTORE_INSTANCES
    global _KVSTORE_LOCKS

    _KVSTORE_BACKENDS.clear()
    _KVSTORE_INSTANCES.clear()
    _KVSTORE_LOCKS.clear()
    for name, cfg in backends.items():
        _KVSTORE_BACKENDS[name] = cfg


async def kvstore_impl(reference: KVStoreReference) -> KVStore:
    backend_name = reference.backend
    cache_key = (backend_name, reference.namespace)

    existing = _KVSTORE_INSTANCES.get(cache_key)
    if existing:
        return existing

    backend_config = _KVSTORE_BACKENDS.get(backend_name)
    if backend_config is None:
        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")

    lock = _KVSTORE_LOCKS[cache_key]
    async with lock:
        existing = _KVSTORE_INSTANCES.get(cache_key)
        if existing:
            return existing

        config = backend_config.model_copy()
        config.namespace = reference.namespace

        if config.type == StorageBackendType.KV_REDIS.value:
            from .redis import RedisKVStoreImpl

            impl = RedisKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_SQLITE.value:
            from .sqlite import SqliteKVStoreImpl

            impl = SqliteKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_POSTGRES.value:
            from .postgres import PostgresKVStoreImpl

            impl = PostgresKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_MONGODB.value:
            from .mongodb import MongoDBKVStoreImpl

            impl = MongoDBKVStoreImpl(config)
        else:
            raise ValueError(f"Unknown kvstore type {config.type}")

        await impl.initialize()
        _KVSTORE_INSTANCES[cache_key] = impl
        return impl
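For orientation, a minimal sketch of the new resolution flow this module provides: backends are registered once (in the server this happens in `_initialize_storage`, shown earlier), and every caller resolves a `KVStoreReference` into a cached, namespaced store. The backend name, the `SqliteKVStoreConfig` constructor arguments, and the `KVStoreReference` constructor arguments are illustrative assumptions, not taken from this diff:

    import asyncio

    from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
    from llama_stack.core.storage.kvstore.kvstore import kvstore_impl, register_kvstore_backends

    # Assumed: SqliteKVStoreConfig takes a db_path; "default" is an arbitrary backend name.
    register_kvstore_backends({"default": SqliteKVStoreConfig(db_path="/tmp/demo_kv.db")})


    async def main() -> None:
        ref = KVStoreReference(backend="default", namespace="prompts")
        store = await kvstore_impl(ref)
        await store.set("greeting", "hello")
        # A second resolution of the same (backend, namespace) pair returns the cached instance.
        assert await kvstore_impl(ref) is store


    asyncio.run(main())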
src/llama_stack/core/storage/kvstore/mongodb/__init__.py (new file, +9)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .mongodb import MongoDBKVStoreImpl

__all__ = ["MongoDBKVStoreImpl"]
src/llama_stack/core/storage/kvstore/mongodb/mongodb.py (new file, +85)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime

from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection

from llama_stack.log import get_logger
from llama_stack.core.storage.kvstore import KVStore

from ..config import MongoDBKVStoreConfig

log = get_logger(name=__name__, category="providers::utils")


class MongoDBKVStoreImpl(KVStore):
    def __init__(self, config: MongoDBKVStoreConfig):
        self.config = config
        self.conn: AsyncMongoClient | None = None

    @property
    def collection(self) -> AsyncCollection:
        if self.conn is None:
            raise RuntimeError("MongoDB connection is not initialized")
        return self.conn[self.config.db][self.config.collection_name]

    async def initialize(self) -> None:
        try:
            # Pass parameters explicitly to satisfy mypy - AsyncMongoClient doesn't accept **dict
            self.conn = AsyncMongoClient(
                host=self.config.host if self.config.host is not None else None,
                port=self.config.port if self.config.port is not None else None,
                username=self.config.user if self.config.user is not None else None,
                password=self.config.password if self.config.password is not None else None,
            )
        except Exception as e:
            log.exception("Could not connect to MongoDB database server")
            raise RuntimeError("Could not connect to MongoDB database server") from e

    def _namespaced_key(self, key: str) -> str:
        if not self.config.namespace:
            return key
        return f"{self.config.namespace}:{key}"

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        key = self._namespaced_key(key)
        update_query = {"$set": {"value": value, "expiration": expiration}}
        await self.collection.update_one({"key": key}, update_query, upsert=True)

    async def get(self, key: str) -> str | None:
        key = self._namespaced_key(key)
        query = {"key": key}
        result = await self.collection.find_one(query, {"value": 1, "_id": 0})
        return result["value"] if result else None

    async def delete(self, key: str) -> None:
        key = self._namespaced_key(key)
        await self.collection.delete_one({"key": key})

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)
        query = {
            "key": {"$gte": start_key, "$lt": end_key},
        }
        cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
        result = []
        async for doc in cursor:
            result.append(doc["value"])
        return result

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)
        query = {"key": {"$gte": start_key, "$lt": end_key}}
        cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
        # AsyncCursor requires async iteration
        result = []
        async for doc in cursor:
            result.append(doc["key"])
        return result
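All of the backends implement range queries over lexicographically ordered keys, which is what callers use for prefix scans (the exact bound semantics vary slightly: MongoDB and the in-memory store use a half-open `[start, end)` interval, while the SQLite backend below uses inclusive bounds). A small runnable illustration of the convention using the in-memory store from kvstore.py; the key names and the `"\xff"` upper-bound sentinel are made up for the example:

    import asyncio

    from llama_stack.core.storage.kvstore.kvstore import InmemoryKVStoreImpl


    async def main() -> None:
        store = InmemoryKVStoreImpl()
        await store.set("sessions:a", "1")
        await store.set("sessions:b", "2")
        await store.set("users:a", "3")
        # Keys compare as strings, so a prefix scan passes the prefix as start_key
        # and a bound just past the prefix as end_key ("\xff" is an assumed sentinel).
        keys = await store.keys_in_range("sessions:", "sessions:\xff")
        assert keys == ["sessions:a", "sessions:b"]


    asyncio.run(main())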
src/llama_stack/core/storage/kvstore/postgres/__init__.py (new file, +7)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .postgres import PostgresKVStoreImpl  # noqa: F401 F403
src/llama_stack/core/storage/kvstore/postgres/postgres.py (new file, +114)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime

import psycopg2
from psycopg2.extras import DictCursor

from llama_stack.log import get_logger

from ..api import KVStore
from ..config import PostgresKVStoreConfig

log = get_logger(name=__name__, category="providers::utils")


class PostgresKVStoreImpl(KVStore):
    def __init__(self, config: PostgresKVStoreConfig):
        self.config = config
        self.conn = None
        self.cursor = None

    async def initialize(self) -> None:
        try:
            self.conn = psycopg2.connect(
                host=self.config.host,
                port=self.config.port,
                database=self.config.db,
                user=self.config.user,
                password=self.config.password,
                sslmode=self.config.ssl_mode,
                sslrootcert=self.config.ca_cert_path,
            )
            self.conn.autocommit = True
            self.cursor = self.conn.cursor(cursor_factory=DictCursor)

            # Create table if it doesn't exist
            self.cursor.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.config.table_name} (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expiration TIMESTAMP
                )
                """
            )
        except Exception as e:
            log.exception("Could not connect to PostgreSQL database server")
            raise RuntimeError("Could not connect to PostgreSQL database server") from e

    def _namespaced_key(self, key: str) -> str:
        if not self.config.namespace:
            return key
        return f"{self.config.namespace}:{key}"

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        key = self._namespaced_key(key)
        self.cursor.execute(
            f"""
            INSERT INTO {self.config.table_name} (key, value, expiration)
            VALUES (%s, %s, %s)
            ON CONFLICT (key) DO UPDATE
            SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
            """,
            (key, value, expiration),
        )

    async def get(self, key: str) -> str | None:
        key = self._namespaced_key(key)
        self.cursor.execute(
            f"""
            SELECT value FROM {self.config.table_name}
            WHERE key = %s
            AND (expiration IS NULL OR expiration > NOW())
            """,
            (key,),
        )
        result = self.cursor.fetchone()
        return result[0] if result else None

    async def delete(self, key: str) -> None:
        key = self._namespaced_key(key)
        self.cursor.execute(
            f"DELETE FROM {self.config.table_name} WHERE key = %s",
            (key,),
        )

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)

        self.cursor.execute(
            f"""
            SELECT value FROM {self.config.table_name}
            WHERE key >= %s AND key < %s
            AND (expiration IS NULL OR expiration > NOW())
            ORDER BY key
            """,
            (start_key, end_key),
        )
        return [row[0] for row in self.cursor.fetchall()]

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)

        self.cursor.execute(
            f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
            (start_key, end_key),
        )
        return [row[0] for row in self.cursor.fetchall()]
src/llama_stack/core/storage/kvstore/redis/__init__.py (new file, +7)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .redis import RedisKVStoreImpl  # noqa: F401
src/llama_stack/core/storage/kvstore/redis/redis.py (new file, +76)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from datetime import datetime

from redis.asyncio import Redis

from ..api import KVStore
from ..config import RedisKVStoreConfig


class RedisKVStoreImpl(KVStore):
    def __init__(self, config: RedisKVStoreConfig):
        self.config = config

    async def initialize(self) -> None:
        self.redis = Redis.from_url(self.config.url)

    def _namespaced_key(self, key: str) -> str:
        if not self.config.namespace:
            return key
        return f"{self.config.namespace}:{key}"

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        key = self._namespaced_key(key)
        await self.redis.set(key, value)
        if expiration:
            await self.redis.expireat(key, expiration)

    async def get(self, key: str) -> str | None:
        key = self._namespaced_key(key)
        value = await self.redis.get(key)
        if value is None:
            return None
        await self.redis.ttl(key)
        return value

    async def delete(self, key: str) -> None:
        key = self._namespaced_key(key)
        await self.redis.delete(key)

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)
        cursor = 0
        pattern = start_key + "*"  # Match all keys starting with start_key prefix
        matching_keys = []
        while True:
            cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)

            for key in keys:
                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
                if start_key <= key_str <= end_key:
                    matching_keys.append(key)

            if cursor == 0:
                break

        # Then fetch all values in a single MGET call
        if matching_keys:
            values = await self.redis.mget(matching_keys)
            return [
                value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
            ]

        return []

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        """Get all keys in the given range."""
        matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
        if not matching_keys:
            return []
        return [k.decode("utf-8") for k in matching_keys]
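A design note on `values_in_range` above: the SCAN pattern is built by appending `*` to `start_key`, so the scan behaves like the other backends' lexicographic range only when `start_key` is a plain prefix of the keys in the interval (glob metacharacters such as `*` or `[` inside keys would change what SCAN matches). The client-side bound check is what actually enforces the range, and note that here both bounds are inclusive. A minimal illustration of that filtering step, with made-up keys:

    # Mirrors the filtering in values_in_range: SCAN narrows candidates by prefix,
    # then each key is kept only if it falls inside the inclusive [start, end] bounds.
    start_key, end_key = "sessions:", "sessions:~"
    candidates = ["sessions:a", "sessions:b", "session", "users:a"]
    kept = [k for k in candidates if start_key <= k <= end_key]
    print(kept)  # ['sessions:a', 'sessions:b']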
src/llama_stack/core/storage/kvstore/sqlite/__init__.py (new file, +7)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .sqlite import SqliteKVStoreImpl  # noqa: F401
src/llama_stack/core/storage/kvstore/sqlite/config.py (new file, +20)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel, Field

from llama_stack_api import json_schema_type


@json_schema_type
class SqliteControlPlaneConfig(BaseModel):
    db_path: str = Field(
        description="File path for the sqlite database",
    )
    table_name: str = Field(
        default="llamastack_control_plane",
        description="Table into which all the keys will be placed",
    )
src/llama_stack/core/storage/kvstore/sqlite/sqlite.py (new file, +174)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from datetime import datetime

import aiosqlite

from llama_stack.log import get_logger

from ..api import KVStore
from ..config import SqliteKVStoreConfig

logger = get_logger(name=__name__, category="providers::utils")


class SqliteKVStoreImpl(KVStore):
    def __init__(self, config: SqliteKVStoreConfig):
        self.db_path = config.db_path
        self.table_name = "kvstore"
        self._conn: aiosqlite.Connection | None = None

    def __str__(self):
        return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})"

    def _is_memory_db(self) -> bool:
        """Check if this is an in-memory database."""
        return self.db_path == ":memory:" or "mode=memory" in self.db_path

    async def initialize(self):
        # Skip directory creation for in-memory databases and file: URIs
        if not self._is_memory_db() and not self.db_path.startswith("file:"):
            db_dir = os.path.dirname(self.db_path)
            if db_dir:  # Only create if there's a directory component
                os.makedirs(db_dir, exist_ok=True)

        # Only use persistent connection for in-memory databases
        # File-based databases use connection-per-operation to avoid hangs
        if self._is_memory_db():
            self._conn = await aiosqlite.connect(self.db_path)
            await self._conn.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.table_name} (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expiration TIMESTAMP
                )
                """
            )
            await self._conn.commit()
        else:
            # For file-based databases, just create the table
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(
                    f"""
                    CREATE TABLE IF NOT EXISTS {self.table_name} (
                        key TEXT PRIMARY KEY,
                        value TEXT,
                        expiration TIMESTAMP
                    )
                    """
                )
                await db.commit()

    async def shutdown(self):
        """Close the persistent connection (only for in-memory databases)."""
        if self._conn:
            await self._conn.close()
            self._conn = None

    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        if self._conn:
            # In-memory database with persistent connection
            await self._conn.execute(
                f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
                (key, value, expiration),
            )
            await self._conn.commit()
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(
                    f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
                    (key, value, expiration),
                )
                await db.commit()

    async def get(self, key: str) -> str | None:
        if self._conn:
            # In-memory database with persistent connection
            async with self._conn.execute(
                f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
            ) as cursor:
                row = await cursor.fetchone()
                if row is None:
                    return None
                value, expiration = row
                if not isinstance(value, str):
                    logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
                    return None
                return value
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute(
                    f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
                ) as cursor:
                    row = await cursor.fetchone()
                    if row is None:
                        return None
                    value, expiration = row
                    if not isinstance(value, str):
                        logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
                        return None
                    return value

    async def delete(self, key: str) -> None:
        if self._conn:
            # In-memory database with persistent connection
            await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
            await self._conn.commit()
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
                await db.commit()

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        if self._conn:
            # In-memory database with persistent connection
            async with self._conn.execute(
                f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
                (start_key, end_key),
            ) as cursor:
                result = []
                async for row in cursor:
                    _, value, _ = row
                    result.append(value)
                return result
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute(
                    f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
                    (start_key, end_key),
                ) as cursor:
                    result = []
                    async for row in cursor:
                        _, value, _ = row
                        result.append(value)
                    return result

    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
        """Get all keys in the given range."""
        if self._conn:
            # In-memory database with persistent connection
            cursor = await self._conn.execute(
                f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
                (start_key, end_key),
            )
            rows = await cursor.fetchall()
            return [row[0] for row in rows]
        else:
            # File-based database with connection per operation
            async with aiosqlite.connect(self.db_path) as db:
                cursor = await db.execute(
                    f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
                    (start_key, end_key),
                )
                rows = await cursor.fetchall()
                return [row[0] for row in rows]
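The split above (persistent connection for `:memory:`, connection-per-operation for files) exists because an in-memory SQLite database disappears as soon as its only connection closes, while file-backed databases tolerate short-lived connections. A hedged round-trip sketch of both modes; the temp path is arbitrary and the `SqliteKVStoreConfig` constructor arguments are assumed:

    import asyncio

    from llama_stack.core.storage.datatypes import SqliteKVStoreConfig
    from llama_stack.core.storage.kvstore.sqlite import SqliteKVStoreImpl


    async def main() -> None:
        # File-based: each operation opens its own aiosqlite connection.
        store = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path="/tmp/demo_kv.db"))
        await store.initialize()
        await store.set("greeting", "hello")
        assert await store.get("greeting") == "hello"

        # In-memory: one persistent connection keeps the database alive.
        mem = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=":memory:"))
        await mem.initialize()
        await mem.set("k", "v")
        assert await mem.get("k") == "v"
        await mem.shutdown()


    asyncio.run(main())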
src/llama_stack/core/storage/sqlstore/__init__.py (new file, +7)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .sqlstore import *  # noqa: F401,F403
src/llama_stack/core/storage/sqlstore/api.py (new file, +10)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack_api import PaginatedResponse
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore

__all__ = ["ColumnDefinition", "ColumnType", "SqlStore", "PaginatedResponse"]
src/llama_stack/core/storage/sqlstore/authorized_sqlstore.py (new file, +336)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import Mapping, Sequence
from typing import Any, Literal

from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.conditions import ProtectedResource
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.log import get_logger

from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore

logger = get_logger(name=__name__, category="providers::utils")

# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly
# or SQL filtering will fall back to conservative mode (safe but less performant)
#
# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
# - Public records (no access_attributes) are always accessible
# - Records with access_attributes require user to match ALL categories that exist in the resource
# - Missing categories in the resource are treated as "no restriction" (allow)
# - Within each category, user needs ANY matching value (OR logic)
# - Between categories, user needs ALL categories to match (AND logic)
SQL_OPTIMIZED_POLICY = [
    AccessRule(
        permit=Scope(actions=list(Action)),
        when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
    ),
]


def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
    """Add access control attributes to a data item."""
    enhanced = dict(item)
    if current_user:
        enhanced["owner_principal"] = current_user.principal
        enhanced["access_attributes"] = current_user.attributes
    else:
        # IMPORTANT: Use empty string and null value (not None) to match public access filter
        # The public access filter in _get_public_access_conditions() expects:
        # - owner_principal = '' (empty string)
        # - access_attributes = null (JSON null, which serializes to the string 'null')
        # Setting them to None (SQL NULL) will cause rows to be filtered out on read.
        enhanced["owner_principal"] = ""
        enhanced["access_attributes"] = None  # Pydantic/JSON will serialize this as JSON null
    return enhanced


class SqlRecord(ProtectedResource):
    def __init__(self, record_id: str, table_name: str, owner: User):
        self.type = f"sql_record::{table_name}"
        self.identifier = record_id
        self.owner = owner


class AuthorizedSqlStore:
    """
    Authorization layer for SqlStore that provides access control functionality.

    This class composes a base SqlStore and adds authorization methods that handle
    access control policies, user attribute capture, and SQL filtering optimization.
    """

    def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
        """
        Initialize the authorization layer.

        :param sql_store: Base SqlStore implementation to wrap
        :param policy: Access control policy to use for authorization
        """
        self.sql_store = sql_store
        self.policy = policy
        self._detect_database_type()
        self._validate_sql_optimized_policy()

    def _detect_database_type(self) -> None:
        """Detect the database type from the underlying SQL store."""
        if not hasattr(self.sql_store, "config"):
            raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")

        self.database_type = self.sql_store.config.type.value
        if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
            raise ValueError(f"Unsupported database type: {self.database_type}")

    def _validate_sql_optimized_policy(self) -> None:
        """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().

        This ensures that if default_policy() changes, we detect the mismatch and
        can update our SQL filtering logic accordingly.
        """
        actual_default = default_policy()

        if SQL_OPTIMIZED_POLICY != actual_default:
            logger.warning(
                f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
                f"SQL filtering will use conservative mode. "
                f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
            )

    async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
        """Create a table with built-in access control support."""

        enhanced_schema = dict(schema)
        if "access_attributes" not in enhanced_schema:
            enhanced_schema["access_attributes"] = ColumnType.JSON
        if "owner_principal" not in enhanced_schema:
            enhanced_schema["owner_principal"] = ColumnType.STRING

        await self.sql_store.create_table(table, enhanced_schema)
        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
        await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)

    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
        """Insert a row or batch of rows with automatic access control attribute capture."""
        current_user = get_authenticated_user()
        enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
        if isinstance(data, Mapping):
            enhanced_data = _enhance_item_with_access_control(data, current_user)
        else:
            enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
        await self.sql_store.insert(table, enhanced_data)

    async def upsert(
        self,
        table: str,
        data: Mapping[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
    ) -> None:
        """Upsert a row with automatic access control attribute capture."""
        current_user = get_authenticated_user()
        enhanced_data = _enhance_item_with_access_control(data, current_user)
        await self.sql_store.upsert(
            table=table,
            data=enhanced_data,
            conflict_columns=conflict_columns,
            update_columns=update_columns,
        )

    async def fetch_all(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        limit: int | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
        cursor: tuple[str, str] | None = None,
    ) -> PaginatedResponse:
        """Fetch all rows with automatic access control filtering."""
        access_where = self._build_access_control_where_clause(self.policy)
        rows = await self.sql_store.fetch_all(
            table=table,
            where=where,
            where_sql=access_where,
            limit=limit,
            order_by=order_by,
            cursor=cursor,
        )

        current_user = get_authenticated_user()
        filtered_rows = []

        for row in rows.data:
            stored_access_attrs = row.get("access_attributes")
            stored_owner_principal = row.get("owner_principal") or ""

            record_id = row.get("id", "unknown")
            sql_record = SqlRecord(
                str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
            )

            if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
                filtered_rows.append(row)

        return PaginatedResponse(
            data=filtered_rows,
            has_more=rows.has_more,
        )

    async def fetch_one(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
    ) -> dict[str, Any] | None:
        """Fetch one row with automatic access control checking."""
        results = await self.fetch_all(
            table=table,
            where=where,
            limit=1,
            order_by=order_by,
        )

        return results.data[0] if results.data else None

    async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
        """Update rows with automatic access control attribute capture."""
        enhanced_data = dict(data)

        current_user = get_authenticated_user()
        if current_user:
            enhanced_data["owner_principal"] = current_user.principal
            enhanced_data["access_attributes"] = current_user.attributes
        else:
            # IMPORTANT: Use empty string for owner_principal to match public access filter
            enhanced_data["owner_principal"] = ""
            enhanced_data["access_attributes"] = None  # Will serialize as JSON null

        await self.sql_store.update(table, enhanced_data, where)

    async def delete(self, table: str, where: Mapping[str, Any]) -> None:
        """Delete rows with automatic access control filtering."""
        await self.sql_store.delete(table, where)

    def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
        """Build SQL WHERE clause for access control filtering.

        Only applies SQL filtering for the default policy to ensure correctness.
        For custom policies, uses conservative filtering to avoid blocking legitimate access.
        """
        current_user = get_authenticated_user()

        if not policy or policy == SQL_OPTIMIZED_POLICY:
            return self._build_default_policy_where_clause(current_user)
        else:
            return self._build_conservative_where_clause()

    def _json_extract(self, column: str, path: str) -> str:
        """Extract JSON value (keeping JSON type).

        Args:
            column: The JSON column name
            path: The JSON path (e.g., 'roles', 'teams')

        Returns:
            SQL expression to extract JSON value
        """
        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
            return f"{column}->'{path}'"
        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
            return f"JSON_EXTRACT({column}, '$.{path}')"
        else:
            raise ValueError(f"Unsupported database type: {self.database_type}")

    def _json_extract_text(self, column: str, path: str) -> str:
        """Extract JSON value as text.

        Args:
            column: The JSON column name
            path: The JSON path (e.g., 'roles', 'teams')

        Returns:
            SQL expression to extract JSON value as text
        """
        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
            return f"{column}->>'{path}'"
        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
            return f"JSON_EXTRACT({column}, '$.{path}')"
        else:
            raise ValueError(f"Unsupported database type: {self.database_type}")

    def _get_public_access_conditions(self) -> list[str]:
        """Get the SQL conditions for public access.

        Public records are those with:
        - owner_principal = '' (empty string)
        - access_attributes is either SQL NULL or JSON null

        Note: Different databases serialize None differently:
        - SQLite: None → JSON null (text = 'null')
        - Postgres: None → SQL NULL (IS NULL)
        """
        conditions = ["owner_principal = ''"]
        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
            # Accept both SQL NULL and JSON null for Postgres compatibility
            # This handles both old rows (SQL NULL) and new rows (JSON null)
            conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
            # SQLite serializes None as JSON null
            conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
        else:
            raise ValueError(f"Unsupported database type: {self.database_type}")
        return conditions

    def _build_default_policy_where_clause(self, current_user: User | None) -> str:
        """Build SQL WHERE clause for the default policy.

        Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
        This means user must match ALL attribute categories that exist in the resource.
        """
        base_conditions = self._get_public_access_conditions()
        user_attr_conditions = []

        if current_user and current_user.attributes:
            for attr_key, user_values in current_user.attributes.items():
                if user_values:
                    value_conditions = []
                    for value in user_values:
                        # Check if JSON array contains the value
                        escaped_value = value.replace("'", "''")
                        json_text = self._json_extract_text("access_attributes", attr_key)
                        value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")

                    if value_conditions:
                        # Check if the category is missing (NULL)
                        category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
                        user_matches_category = f"({' OR '.join(value_conditions)})"
                        user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")

        if user_attr_conditions:
            all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
            base_conditions.append(all_requirements_met)

        return f"({' OR '.join(base_conditions)})"

    def _build_conservative_where_clause(self) -> str:
        """Conservative SQL filtering for custom policies.

        Only filters records we're 100% certain would be denied by any reasonable policy.
        """
        current_user = get_authenticated_user()

        if not current_user:
            # Only allow public records
            base_conditions = self._get_public_access_conditions()
            return f"({' OR '.join(base_conditions)})"

        return "1=1"
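A hedged sketch of how a provider might compose the wrapper above with the SQLAlchemy backend that follows. The table schema is invented, and the `SqliteSqlStoreConfig` name and constructor are assumptions (any config whose `type` is a supported SQL backend and that the SQLAlchemy store accepts would do):

    import asyncio

    from llama_stack.core.access_control.access_control import default_policy
    from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig  # assumed config class
    from llama_stack.core.storage.sqlstore.api import ColumnType
    from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
    from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl


    async def main() -> None:
        base = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path="/tmp/demo_sql.db"))
        store = AuthorizedSqlStore(base, default_policy())

        # create_table injects the owner_principal / access_attributes columns automatically.
        await store.create_table("notes", {"id": ColumnType.STRING, "body": ColumnType.TEXT})

        # insert stamps the authenticated user on each row; outside a request context
        # there is no user, so rows are stamped as public ('' / JSON null).
        await store.insert("notes", {"id": "1", "body": "hello"})

        # fetch_all applies the policy twice: as a SQL WHERE clause and as a
        # per-row is_action_allowed() check on the results.
        page = await store.fetch_all("notes")
        print(page.data)


    asyncio.run(main())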
src/llama_stack/core/storage/sqlstore/sqlalchemy_sqlstore.py (new file, +374; listing truncated below)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping, Sequence
from typing import Any, Literal, cast

from sqlalchemy import (
    JSON,
    Boolean,
    Column,
    DateTime,
    Float,
    Integer,
    MetaData,
    String,
    Table,
    Text,
    event,
    inspect,
    select,
    text,
)
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.sql.elements import ColumnElement

from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
from llama_stack.log import get_logger
from llama_stack_api import PaginatedResponse

from .api import ColumnDefinition, ColumnType, SqlStore

logger = get_logger(name=__name__, category="providers::utils")

TYPE_MAPPING: dict[ColumnType, Any] = {
    ColumnType.INTEGER: Integer,
    ColumnType.STRING: String,
    ColumnType.FLOAT: Float,
    ColumnType.BOOLEAN: Boolean,
    ColumnType.DATETIME: DateTime,
    ColumnType.TEXT: Text,
    ColumnType.JSON: JSON,
}


def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
    """Return a SQLAlchemy expression for a where condition.

    `value` may be a simple scalar (equality) or a mapping like {">": 123}.
    The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
    """
    if isinstance(value, Mapping):
        if len(value) != 1:
            raise ValueError(f"Operator mapping must have a single operator, got: {value}")
        op, operand = next(iter(value.items()))
        if op == "==" or op == "=":
            return cast(ColumnElement[Any], column == operand)
        if op == ">":
            return cast(ColumnElement[Any], column > operand)
        if op == "<":
            return cast(ColumnElement[Any], column < operand)
        if op == ">=":
            return cast(ColumnElement[Any], column >= operand)
        if op == "<=":
            return cast(ColumnElement[Any], column <= operand)
        raise ValueError(f"Unsupported operator '{op}' in where mapping")
    return cast(ColumnElement[Any], column == value)


class SqlAlchemySqlStoreImpl(SqlStore):
    def __init__(self, config: SqlAlchemySqlStoreConfig):
        self.config = config
        self._is_sqlite_backend = "sqlite" in self.config.engine_str
        self.async_session = async_sessionmaker(self.create_engine())
        self.metadata = MetaData()

    def create_engine(self) -> AsyncEngine:
        # Configure connection args for better concurrency support
        connect_args = {}
        if self._is_sqlite_backend:
            # SQLite-specific optimizations for concurrent access
            # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
            connect_args["timeout"] = 5.0
            connect_args["check_same_thread"] = False  # Allow usage across asyncio tasks

        engine = create_async_engine(
            self.config.engine_str,
            pool_pre_ping=True,
            connect_args=connect_args,
        )

        # Enable WAL mode for SQLite to support concurrent readers and writers
        if self._is_sqlite_backend:

            @event.listens_for(engine.sync_engine, "connect")
            def set_sqlite_pragma(dbapi_conn, connection_record):
                cursor = dbapi_conn.cursor()
                # Enable Write-Ahead Logging for better concurrency
                cursor.execute("PRAGMA journal_mode=WAL")
                # Set busy timeout to 5 seconds (retry instead of immediate failure)
                # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
                cursor.execute("PRAGMA busy_timeout=5000")
                # Use NORMAL synchronous mode for better performance (still safe with WAL)
                cursor.execute("PRAGMA synchronous=NORMAL")
                cursor.close()

        return engine

    async def create_table(
        self,
        table: str,
        schema: Mapping[str, ColumnType | ColumnDefinition],
    ) -> None:
        if not schema:
            raise ValueError(f"No columns defined for table '{table}'.")

        sqlalchemy_columns: list[Column] = []

        for col_name, col_props in schema.items():
            col_type = None
            is_primary_key = False
            is_nullable = True

            if isinstance(col_props, ColumnType):
                col_type = col_props
            elif isinstance(col_props, ColumnDefinition):
                col_type = col_props.type
                is_primary_key = col_props.primary_key
                is_nullable = col_props.nullable

            sqlalchemy_type = TYPE_MAPPING.get(col_type)
            if not sqlalchemy_type:
                raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.")

            sqlalchemy_columns.append(
                Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
            )

        if table not in self.metadata.tables:
            sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
        else:
            sqlalchemy_table = self.metadata.tables[table]

        engine = self.create_engine()
        async with engine.begin() as conn:
            await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)

    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
        async with self.async_session() as session:
            await session.execute(self.metadata.tables[table].insert(), data)
            await session.commit()

    async def upsert(
        self,
        table: str,
        data: Mapping[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
    ) -> None:
        table_obj = self.metadata.tables[table]
        dialect_insert = self._get_dialect_insert(table_obj)
        insert_stmt = dialect_insert.values(**data)

        if update_columns is None:
            update_columns = [col for col in data.keys() if col not in conflict_columns]

        update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
        conflict_cols = [table_obj.c[col] for col in conflict_columns]

        stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)

        async with self.async_session() as session:
            await session.execute(stmt)
            await session.commit()

    async def fetch_all(
        self,
        table: str,
        where: Mapping[str, Any] | None = None,
        where_sql: str | None = None,
        limit: int | None = None,
        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
        cursor: tuple[str, str] | None = None,
    ) -> PaginatedResponse:
        async with self.async_session() as session:
            table_obj = self.metadata.tables[table]
            query = select(table_obj)

            if where:
                for key, value in where.items():
                    query = query.where(_build_where_expr(table_obj.c[key], value))

            if where_sql:
                query = query.where(text(where_sql))

            # Handle cursor-based pagination
            if cursor:
                # Validate cursor tuple format
                if not isinstance(cursor, tuple) or len(cursor) != 2:
                    raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")

                # Require order_by for cursor pagination
                if not order_by:
                    raise ValueError("order_by is required when using cursor pagination")

                # Only support single-column ordering for cursor pagination
                if len(order_by) != 1:
                    raise ValueError(
                        f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
                    )

                cursor_key_column, cursor_id = cursor
                order_column, order_direction = order_by[0]

                # Verify cursor_key_column exists
                if cursor_key_column not in table_obj.c:
                    raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")

                # Get cursor value for the order column
                cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
                cursor_result = await session.execute(cursor_query)
|
cursor_row = cursor_result.fetchone()
|
||||||
|
|
||||||
|
if not cursor_row:
|
||||||
|
raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")
|
||||||
|
|
||||||
|
cursor_value = cursor_row[0]
|
||||||
|
|
||||||
|
# Apply cursor condition based on sort direction
|
||||||
|
if order_direction == "desc":
|
||||||
|
query = query.where(table_obj.c[order_column] < cursor_value)
|
||||||
|
else:
|
||||||
|
query = query.where(table_obj.c[order_column] > cursor_value)
|
||||||
|
|
||||||
|
# Apply ordering
|
||||||
|
if order_by:
|
||||||
|
if not isinstance(order_by, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||||
|
)
|
||||||
|
for order in order_by:
|
||||||
|
if not isinstance(order, tuple):
|
||||||
|
raise ValueError(
|
||||||
|
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||||
|
)
|
||||||
|
name, order_type = order
|
||||||
|
if name not in table_obj.c:
|
||||||
|
raise ValueError(f"Column '{name}' not found in table '{table}'")
|
||||||
|
if order_type == "asc":
|
||||||
|
query = query.order_by(table_obj.c[name].asc())
|
||||||
|
elif order_type == "desc":
|
||||||
|
query = query.order_by(table_obj.c[name].desc())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
||||||
|
|
||||||
|
# Fetch limit + 1 to determine has_more
|
||||||
|
fetch_limit = limit
|
||||||
|
if limit:
|
||||||
|
fetch_limit = limit + 1
|
||||||
|
|
||||||
|
if fetch_limit:
|
||||||
|
query = query.limit(fetch_limit)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
# Iterate directly - if no rows, list comprehension yields empty list
|
||||||
|
rows = [dict(row._mapping) for row in result]
|
||||||
|
|
||||||
|
# Always return pagination result
|
||||||
|
has_more = False
|
||||||
|
if limit and len(rows) > limit:
|
||||||
|
has_more = True
|
||||||
|
rows = rows[:limit]
|
||||||
|
|
||||||
|
return PaginatedResponse(data=rows, has_more=has_more)
|
||||||
|
|
||||||
|
async def fetch_one(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
where: Mapping[str, Any] | None = None,
|
||||||
|
where_sql: str | None = None,
|
||||||
|
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
|
||||||
|
if not result.data:
|
||||||
|
return None
|
||||||
|
return result.data[0]
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
data: Mapping[str, Any],
|
||||||
|
where: Mapping[str, Any],
|
||||||
|
) -> None:
|
||||||
|
if not where:
|
||||||
|
raise ValueError("where is required for update")
|
||||||
|
|
||||||
|
async with self.async_session() as session:
|
||||||
|
stmt = self.metadata.tables[table].update()
|
||||||
|
for key, value in where.items():
|
||||||
|
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||||
|
await session.execute(stmt, data)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||||
|
if not where:
|
||||||
|
raise ValueError("where is required for delete")
|
||||||
|
|
||||||
|
async with self.async_session() as session:
|
||||||
|
stmt = self.metadata.tables[table].delete()
|
||||||
|
for key, value in where.items():
|
||||||
|
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||||
|
await session.execute(stmt)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def add_column_if_not_exists(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
column_name: str,
|
||||||
|
column_type: ColumnType,
|
||||||
|
nullable: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Add a column to an existing table if the column doesn't already exist."""
|
||||||
|
engine = self.create_engine()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
|
||||||
|
def check_column_exists(sync_conn):
|
||||||
|
inspector = inspect(sync_conn)
|
||||||
|
|
||||||
|
table_names = inspector.get_table_names()
|
||||||
|
if table not in table_names:
|
||||||
|
return False, False # table doesn't exist, column doesn't exist
|
||||||
|
|
||||||
|
existing_columns = inspector.get_columns(table)
|
||||||
|
column_names = [col["name"] for col in existing_columns]
|
||||||
|
|
||||||
|
return True, column_name in column_names # table exists, column exists or not
|
||||||
|
|
||||||
|
table_exists, column_exists = await conn.run_sync(check_column_exists)
|
||||||
|
if not table_exists or column_exists:
|
||||||
|
return
|
||||||
|
|
||||||
|
sqlalchemy_type = TYPE_MAPPING.get(column_type)
|
||||||
|
if not sqlalchemy_type:
|
||||||
|
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
|
||||||
|
|
||||||
|
# Create the ALTER TABLE statement
|
||||||
|
# Note: We need to get the dialect-specific type name
|
||||||
|
dialect = engine.dialect
|
||||||
|
type_impl = sqlalchemy_type()
|
||||||
|
compiled_type = type_impl.compile(dialect=dialect)
|
||||||
|
|
||||||
|
nullable_clause = "" if nullable else " NOT NULL"
|
||||||
|
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
|
||||||
|
|
||||||
|
await conn.execute(add_column_sql)
|
||||||
|
except Exception as e:
|
||||||
|
# If any error occurs during migration, log it but don't fail
|
||||||
|
# The table creation will handle adding the column
|
||||||
|
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_dialect_insert(self, table: Table):
|
||||||
|
if self._is_sqlite_backend:
|
||||||
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||||
|
|
||||||
|
return sqlite_insert(table)
|
||||||
|
else:
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
return pg_insert(table)
|
||||||
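For orientation, a minimal usage sketch of the store above. The table name, columns, and values are hypothetical, and the STRING/INTEGER members of ColumnType are assumed; only methods defined in this class are exercised.

# Hedged sketch: create a table, upsert a row, then page through results
# with the keyed cursor supported by fetch_all.
async def demo(store: SqlAlchemySqlStoreImpl) -> None:
    await store.create_table(
        "events",  # hypothetical table
        {
            "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
            "created_at": ColumnType.INTEGER,
        },
    )
    await store.upsert("events", {"id": "evt_1", "created_at": 1}, conflict_columns=["id"])
    page = await store.fetch_all("events", order_by=[("created_at", "desc")], limit=10)
    if page.has_more and page.data:
        # Resume after the last row of the previous page: cursor is (key_column, cursor_id).
        await store.fetch_all(
            "events",
            order_by=[("created_at", "desc")],
            limit=10,
            cursor=("id", page.data[-1]["id"]),
        )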
src/llama_stack/core/storage/sqlstore/sqlstore.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from threading import Lock
from typing import Annotated, cast

from pydantic import Field

from llama_stack.core.storage.datatypes import (
    PostgresSqlStoreConfig,
    SqliteSqlStoreConfig,
    SqlStoreReference,
    StorageBackendConfig,
    StorageBackendType,
)

from .api import SqlStore

sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]

_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
_SQLSTORE_LOCKS: dict[str, Lock] = {}


SqlStoreConfig = Annotated[
    SqliteSqlStoreConfig | PostgresSqlStoreConfig,
    Field(discriminator="type"),
]


def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
    """Get pip packages for SQL store config, handling both dict and object cases."""
    if isinstance(store_config, dict):
        store_type = store_config.get("type")
        if store_type == StorageBackendType.SQL_SQLITE.value:
            return SqliteSqlStoreConfig.pip_packages()
        elif store_type == StorageBackendType.SQL_POSTGRES.value:
            return PostgresSqlStoreConfig.pip_packages()
        else:
            raise ValueError(f"Unknown SQL store type: {store_type}")
    else:
        return store_config.pip_packages()


def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
    backend_name = reference.backend

    backend_config = _SQLSTORE_BACKENDS.get(backend_name)
    if backend_config is None:
        raise ValueError(
            f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
        )

    existing = _SQLSTORE_INSTANCES.get(backend_name)
    if existing:
        return existing

    lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
    with lock:
        existing = _SQLSTORE_INSTANCES.get(backend_name)
        if existing:
            return existing

        if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
            from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

            config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
            instance = SqlAlchemySqlStoreImpl(config)
            _SQLSTORE_INSTANCES[backend_name] = instance
            return instance
        else:
            raise ValueError(f"Unknown sqlstore type {backend_config.type}")


def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Register the set of available SQL store backends for reference resolution."""
    global _SQLSTORE_BACKENDS
    global _SQLSTORE_INSTANCES

    _SQLSTORE_BACKENDS.clear()
    _SQLSTORE_INSTANCES.clear()
    _SQLSTORE_LOCKS.clear()
    for name, cfg in backends.items():
        _SQLSTORE_BACKENDS[name] = cfg
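A hedged sketch of how the registry above is driven. The backend name and the SqliteSqlStoreConfig/SqlStoreReference field values are illustrative, not taken from this patch.

from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends, sqlstore_impl

# Done once at stack startup, before any provider resolves a reference:
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path="/tmp/stack.db")})

# Providers then resolve a shared, lazily constructed instance by backend name:
store = sqlstore_impl(SqlStoreReference(backend="sql_default"))
assert store is sqlstore_impl(SqlStoreReference(backend="sql_default"))  # cached per backend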
@@ -13,7 +13,7 @@ import pydantic
 from llama_stack.core.datatypes import RoutableObjectWithProvider
 from llama_stack.core.storage.datatypes import KVStoreReference
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.core.storage.kvstore import KVStore, kvstore_impl

 logger = get_logger(__name__, category="core::registry")

@@ -35,8 +35,8 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
 )
 from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
 from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
-from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
+from llama_stack.core.storage.kvstore.config import PostgresKVStoreConfig
+from llama_stack.core.storage.sqlstore.sqlstore import PostgresSqlStoreConfig
 from llama_stack_api import RemoteProviderSpec

@@ -38,10 +38,10 @@ from llama_stack.core.storage.datatypes import (
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
-from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
+from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
+from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages
+from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
 from llama_stack_api import DatasetPurpose, ModelType
@@ -7,7 +7,7 @@

 from llama_stack.core.datatypes import AccessRule
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
+from llama_stack.core.storage.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 from llama_stack_api import (
     Agents,

@@ -7,7 +7,7 @@
 from typing import Any

 from llama_stack.core.datatypes import AccessRule, Api
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack_api import Files, Inference, Models

 from .batches import ReferenceBatchesImpl

@@ -17,7 +17,7 @@ from openai.types.batch import BatchError, Errors
 from pydantic import BaseModel

 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.core.storage.kvstore import KVStore
 from llama_stack_api import (
     Batches,
     BatchObject,

@@ -6,7 +6,7 @@
 from typing import Any

 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse

@@ -9,7 +9,7 @@ from typing import Any
 from tqdm import tqdm

 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack_api import (
     Agents,
     Benchmark,
@@ -15,9 +15,8 @@ from llama_stack.core.datatypes import AccessRule
 from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.files.form_data import parse_expires_after
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     ExpiresAfter,
     Files,
@@ -28,6 +27,7 @@ from llama_stack_api import (
     Order,
     ResourceNotFoundError,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

 from .config import LocalfsFilesImplConfig
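With the column types now exported from llama_stack_api.internal.sqlstore, provider table schemas are declared against the API package rather than the providers tree. A small hedged sketch; the schema is illustrative and the STRING/INTEGER enum members are assumptions:

from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

# Hypothetical files-metadata table: a string primary key plus two plain columns.
FILES_SCHEMA = {
    "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
    "filename": ColumnType.STRING,
    "created_at": ColumnType.INTEGER,
}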
@@ -15,8 +15,7 @@ import numpy as np
 from numpy.typing import NDArray

 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
 from llama_stack_api import (
@@ -32,6 +31,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore

 from .config import FaissVectorIOConfig

@@ -15,8 +15,7 @@ import sqlite_vec  # type: ignore[import-untyped]
 from numpy.typing import NDArray

 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
@@ -35,6 +34,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore

 logger = get_logger(name=__name__, category="vector_io")
@@ -5,7 +5,7 @@
 # the root directory of this source tree.


-from llama_stack.providers.utils.kvstore import kvstore_dependencies
+from llama_stack.core.storage.kvstore import kvstore_dependencies
 from llama_stack_api import (
     Api,
     InlineProviderSpec,

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
+from llama_stack.core.storage.sqlstore.sqlstore import sql_store_pip_packages
 from llama_stack_api import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
@@ -6,7 +6,7 @@
 from typing import Any
 from urllib.parse import parse_qs, urlparse

-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
 from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse

@@ -11,9 +11,8 @@ from fastapi import Depends, File, Form, Response, UploadFile

 from llama_stack.core.datatypes import AccessRule
 from llama_stack.providers.utils.files.form_data import parse_expires_after
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     ExpiresAfter,
     Files,
@@ -24,6 +23,7 @@ from llama_stack_api import (
     Order,
     ResourceNotFoundError,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType
 from openai import OpenAI

 from .config import OpenAIFilesImplConfig

@@ -20,9 +20,8 @@ if TYPE_CHECKING:
 from llama_stack.core.datatypes import AccessRule
 from llama_stack.core.id_generation import generate_object_id
 from llama_stack.providers.utils.files.form_data import parse_expires_after
-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
 from llama_stack_api import (
     ExpiresAfter,
     Files,
@@ -33,6 +32,7 @@ from llama_stack_api import (
     Order,
     ResourceNotFoundError,
 )
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

 from .config import S3FilesImplConfig

@@ -13,8 +13,7 @@ from numpy.typing import NDArray

 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
 from llama_stack_api import (
@@ -27,6 +26,7 @@ from llama_stack_api import (
     VectorStore,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore

 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig

@@ -13,8 +13,7 @@ from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusC

 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_WEIGHTED,
@@ -34,6 +33,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore

 from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig

@@ -15,8 +15,7 @@ from pydantic import BaseModel, TypeAdapter

 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
 from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
@@ -31,6 +30,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore

 from .config import PGVectorVectorIOConfig

@@ -15,7 +15,7 @@ from qdrant_client.models import PointStruct

 from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
-from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
 from llama_stack_api import (

@@ -14,8 +14,7 @@ from weaviate.classes.query import Filter, HybridFusion

 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     RERANKER_TYPE_RRF,
@@ -35,6 +34,7 @@ from llama_stack_api import (
     VectorStoreNotFoundError,
     VectorStoresProtocolPrivate,
 )
+from llama_stack_api.internal.kvstore import KVStore

 from .config import WeaviateVectorIOConfig

@@ -19,9 +19,9 @@ from llama_stack_api import (
     Order,
 )

-from ..sqlstore.api import ColumnDefinition, ColumnType
-from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
+from llama_stack.core.storage.sqlstore.api import ColumnDefinition, ColumnType
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl

 logger = get_logger(name=__name__, category="inference")
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .kvstore import *  # noqa: F401, F403
+from llama_stack.core.storage.kvstore import *  # noqa: F401,F403
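This shim keeps the legacy import path alive for existing provider code; a hedged illustration of the intended equivalence:

# Old location, now a star re-export of the core module:
from llama_stack.providers.utils.kvstore import kvstore_impl as legacy_kvstore_impl
# New canonical location:
from llama_stack.core.storage.kvstore import kvstore_impl

# Via the star re-export, both names should be bound to the same function object.
assert legacy_kvstore_impl is kvstore_impl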
@@ -4,18 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from datetime import datetime
-from typing import Protocol
-
-
-class KVStore(Protocol):
-    # TODO: make the value type bytes instead of str
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ...
-
-    async def get(self, key: str) -> str | None: ...
-
-    async def delete(self, key: str) -> None: ...
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
+from llama_stack.core.storage.kvstore.api import *  # noqa: F401,F403

@@ -4,36 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Annotated
-
-from pydantic import Field
-
-from llama_stack.core.storage.datatypes import (
-    MongoDBKVStoreConfig,
-    PostgresKVStoreConfig,
-    RedisKVStoreConfig,
-    SqliteKVStoreConfig,
-    StorageBackendType,
-)
-
-KVStoreConfig = Annotated[
-    RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig, Field(discriminator="type")
-]
-
-
-def get_pip_packages(store_config: dict | KVStoreConfig) -> list[str]:
-    """Get pip packages for KV store config, handling both dict and object cases."""
-    if isinstance(store_config, dict):
-        store_type = store_config.get("type")
-        if store_type == StorageBackendType.KV_SQLITE.value:
-            return SqliteKVStoreConfig.pip_packages()
-        elif store_type == StorageBackendType.KV_POSTGRES.value:
-            return PostgresKVStoreConfig.pip_packages()
-        elif store_type == StorageBackendType.KV_REDIS.value:
-            return RedisKVStoreConfig.pip_packages()
-        elif store_type == StorageBackendType.KV_MONGODB.value:
-            return MongoDBKVStoreConfig.pip_packages()
-        else:
-            raise ValueError(f"Unknown KV store type: {store_type}")
-    else:
-        return store_config.pip_packages()
+from llama_stack.core.storage.kvstore.config import *  # noqa: F401,F403
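The relocated get_pip_packages helper keeps its dict-or-object dispatch; a hedged sketch (the db_path key and value are illustrative, not confirmed by this patch):

from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.core.storage.kvstore.config import get_pip_packages

# Raw dict configs are matched on the StorageBackendType enum value.
deps = get_pip_packages({"type": StorageBackendType.KV_SQLITE.value, "db_path": "/tmp/kv.db"})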
@@ -4,115 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from __future__ import annotations
-
-import asyncio
-from collections import defaultdict
-
-from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
-
-from .api import KVStore
-from .config import KVStoreConfig
-
-
-def kvstore_dependencies():
-    """
-    Returns all possible kvstore dependencies for registry/provider specifications.
-
-    NOTE: For specific kvstore implementations, use config.pip_packages instead.
-    This function returns the union of all dependencies for cases where the specific
-    kvstore type is not known at declaration time (e.g., provider registries).
-    """
-    return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
-
-
-class InmemoryKVStoreImpl(KVStore):
-    def __init__(self):
-        self._store = {}
-
-    async def initialize(self) -> None:
-        pass
-
-    async def get(self, key: str) -> str | None:
-        return self._store.get(key)
-
-    async def set(self, key: str, value: str) -> None:
-        self._store[key] = value
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        """Get all keys in the given range."""
-        return [key for key in self._store.keys() if key >= start_key and key < end_key]
-
-    async def delete(self, key: str) -> None:
-        del self._store[key]
-
-
-_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
-_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
-_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
-
-
-def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
-    """Register the set of available KV store backends for reference resolution."""
-    global _KVSTORE_BACKENDS
-    global _KVSTORE_INSTANCES
-    global _KVSTORE_LOCKS
-
-    _KVSTORE_BACKENDS.clear()
-    _KVSTORE_INSTANCES.clear()
-    _KVSTORE_LOCKS.clear()
-    for name, cfg in backends.items():
-        _KVSTORE_BACKENDS[name] = cfg
-
-
-async def kvstore_impl(reference: KVStoreReference) -> KVStore:
-    backend_name = reference.backend
-    cache_key = (backend_name, reference.namespace)
-
-    existing = _KVSTORE_INSTANCES.get(cache_key)
-    if existing:
-        return existing
-
-    backend_config = _KVSTORE_BACKENDS.get(backend_name)
-    if backend_config is None:
-        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
-
-    lock = _KVSTORE_LOCKS[cache_key]
-    async with lock:
-        existing = _KVSTORE_INSTANCES.get(cache_key)
-        if existing:
-            return existing
-
-        config = backend_config.model_copy()
-        config.namespace = reference.namespace
-
-        if config.type == StorageBackendType.KV_REDIS.value:
-            from .redis import RedisKVStoreImpl
-
-            impl = RedisKVStoreImpl(config)
-        elif config.type == StorageBackendType.KV_SQLITE.value:
-            from .sqlite import SqliteKVStoreImpl
-
-            impl = SqliteKVStoreImpl(config)
-        elif config.type == StorageBackendType.KV_POSTGRES.value:
-            from .postgres import PostgresKVStoreImpl
-
-            impl = PostgresKVStoreImpl(config)
-        elif config.type == StorageBackendType.KV_MONGODB.value:
-            from .mongodb import MongoDBKVStoreImpl
-
-            impl = MongoDBKVStoreImpl(config)
-        else:
-            raise ValueError(f"Unknown kvstore type {config.type}")
-
-        await impl.initialize()
-        _KVSTORE_INSTANCES[cache_key] = impl
-        return impl
+from llama_stack.core.storage.kvstore.kvstore import *  # noqa: F401,F403
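A hedged sketch of resolving a namespaced KV store through the relocated registry (the backend name, namespace, and key below are illustrative):

from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.core.storage.kvstore.kvstore import kvstore_impl, register_kvstore_backends

async def demo(backends: dict) -> None:
    register_kvstore_backends(backends)  # once, at startup
    # Instances are cached per (backend, namespace) pair behind an asyncio lock.
    kv = await kvstore_impl(KVStoreReference(backend="kv_default", namespace="prompts"))
    await kv.set("prompt:1", "hello")
    assert await kv.get("prompt:1") == "hello"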
@@ -4,6 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .mongodb import MongoDBKVStoreImpl
-
-__all__ = ["MongoDBKVStoreImpl"]
+from llama_stack.core.storage.kvstore.mongodb import *  # noqa: F401,F403

@@ -4,82 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from datetime import datetime
-
-from pymongo import AsyncMongoClient
-from pymongo.asynchronous.collection import AsyncCollection
-
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore
-
-from ..config import MongoDBKVStoreConfig
-
-log = get_logger(name=__name__, category="providers::utils")
-
-
-class MongoDBKVStoreImpl(KVStore):
-    def __init__(self, config: MongoDBKVStoreConfig):
-        self.config = config
-        self.conn: AsyncMongoClient | None = None
-
-    @property
-    def collection(self) -> AsyncCollection:
-        if self.conn is None:
-            raise RuntimeError("MongoDB connection is not initialized")
-        return self.conn[self.config.db][self.config.collection_name]
-
-    async def initialize(self) -> None:
-        try:
-            # Pass parameters explicitly to satisfy mypy - AsyncMongoClient doesn't accept **dict
-            self.conn = AsyncMongoClient(
-                host=self.config.host if self.config.host is not None else None,
-                port=self.config.port if self.config.port is not None else None,
-                username=self.config.user if self.config.user is not None else None,
-                password=self.config.password if self.config.password is not None else None,
-            )
-        except Exception as e:
-            log.exception("Could not connect to MongoDB database server")
-            raise RuntimeError("Could not connect to MongoDB database server") from e
-
-    def _namespaced_key(self, key: str) -> str:
-        if not self.config.namespace:
-            return key
-        return f"{self.config.namespace}:{key}"
-
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
-        key = self._namespaced_key(key)
-        update_query = {"$set": {"value": value, "expiration": expiration}}
-        await self.collection.update_one({"key": key}, update_query, upsert=True)
-
-    async def get(self, key: str) -> str | None:
-        key = self._namespaced_key(key)
-        query = {"key": key}
-        result = await self.collection.find_one(query, {"value": 1, "_id": 0})
-        return result["value"] if result else None
-
-    async def delete(self, key: str) -> None:
-        key = self._namespaced_key(key)
-        await self.collection.delete_one({"key": key})
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-        query = {
-            "key": {"$gte": start_key, "$lt": end_key},
-        }
-        cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
-        result = []
-        async for doc in cursor:
-            result.append(doc["value"])
-        return result
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-        query = {"key": {"$gte": start_key, "$lt": end_key}}
-        cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
-        # AsyncCursor requires async iteration
-        result = []
-        async for doc in cursor:
-            result.append(doc["key"])
-        return result
+from llama_stack.core.storage.kvstore.mongodb.mongodb import *  # noqa: F401,F403
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .postgres import PostgresKVStoreImpl  # noqa: F401 F403
+from llama_stack.core.storage.kvstore.postgres import *  # noqa: F401,F403

@@ -4,111 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from datetime import datetime
-
-import psycopg2
-from psycopg2.extras import DictCursor
-
-from llama_stack.log import get_logger
-
-from ..api import KVStore
-from ..config import PostgresKVStoreConfig
-
-log = get_logger(name=__name__, category="providers::utils")
-
-
-class PostgresKVStoreImpl(KVStore):
-    def __init__(self, config: PostgresKVStoreConfig):
-        self.config = config
-        self.conn = None
-        self.cursor = None
-
-    async def initialize(self) -> None:
-        try:
-            self.conn = psycopg2.connect(
-                host=self.config.host,
-                port=self.config.port,
-                database=self.config.db,
-                user=self.config.user,
-                password=self.config.password,
-                sslmode=self.config.ssl_mode,
-                sslrootcert=self.config.ca_cert_path,
-            )
-            self.conn.autocommit = True
-            self.cursor = self.conn.cursor(cursor_factory=DictCursor)
-
-            # Create table if it doesn't exist
-            self.cursor.execute(
-                f"""
-                CREATE TABLE IF NOT EXISTS {self.config.table_name} (
-                    key TEXT PRIMARY KEY,
-                    value TEXT,
-                    expiration TIMESTAMP
-                )
-                """
-            )
-        except Exception as e:
-            log.exception("Could not connect to PostgreSQL database server")
-            raise RuntimeError("Could not connect to PostgreSQL database server") from e
-
-    def _namespaced_key(self, key: str) -> str:
-        if not self.config.namespace:
-            return key
-        return f"{self.config.namespace}:{key}"
-
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
-        key = self._namespaced_key(key)
-        self.cursor.execute(
-            f"""
-            INSERT INTO {self.config.table_name} (key, value, expiration)
-            VALUES (%s, %s, %s)
-            ON CONFLICT (key) DO UPDATE
-            SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
-            """,
-            (key, value, expiration),
-        )
-
-    async def get(self, key: str) -> str | None:
-        key = self._namespaced_key(key)
-        self.cursor.execute(
-            f"""
-            SELECT value FROM {self.config.table_name}
-            WHERE key = %s
-            AND (expiration IS NULL OR expiration > NOW())
-            """,
-            (key,),
-        )
-        result = self.cursor.fetchone()
-        return result[0] if result else None
-
-    async def delete(self, key: str) -> None:
-        key = self._namespaced_key(key)
-        self.cursor.execute(
-            f"DELETE FROM {self.config.table_name} WHERE key = %s",
-            (key,),
-        )
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-
-        self.cursor.execute(
-            f"""
-            SELECT value FROM {self.config.table_name}
-            WHERE key >= %s AND key < %s
-            AND (expiration IS NULL OR expiration > NOW())
-            ORDER BY key
-            """,
-            (start_key, end_key),
-        )
-        return [row[0] for row in self.cursor.fetchall()]
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-
-        self.cursor.execute(
-            f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
-            (start_key, end_key),
-        )
-        return [row[0] for row in self.cursor.fetchall()]
+from llama_stack.core.storage.kvstore.postgres.postgres import *  # noqa: F401,F403
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .redis import RedisKVStoreImpl  # noqa: F401
+from llama_stack.core.storage.kvstore.redis import *  # noqa: F401,F403

@@ -4,73 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from datetime import datetime
-
-from redis.asyncio import Redis
-
-from ..api import KVStore
-from ..config import RedisKVStoreConfig
-
-
-class RedisKVStoreImpl(KVStore):
-    def __init__(self, config: RedisKVStoreConfig):
-        self.config = config
-
-    async def initialize(self) -> None:
-        self.redis = Redis.from_url(self.config.url)
-
-    def _namespaced_key(self, key: str) -> str:
-        if not self.config.namespace:
-            return key
-        return f"{self.config.namespace}:{key}"
-
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
-        key = self._namespaced_key(key)
-        await self.redis.set(key, value)
-        if expiration:
-            await self.redis.expireat(key, expiration)
-
-    async def get(self, key: str) -> str | None:
-        key = self._namespaced_key(key)
-        value = await self.redis.get(key)
-        if value is None:
-            return None
-        await self.redis.ttl(key)
-        return value
-
-    async def delete(self, key: str) -> None:
-        key = self._namespaced_key(key)
-        await self.redis.delete(key)
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        start_key = self._namespaced_key(start_key)
-        end_key = self._namespaced_key(end_key)
-        cursor = 0
-        pattern = start_key + "*"  # Match all keys starting with start_key prefix
-        matching_keys = []
-        while True:
-            cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
-
-            for key in keys:
-                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
-                if start_key <= key_str <= end_key:
-                    matching_keys.append(key)
-
-            if cursor == 0:
-                break
-
-        # Then fetch all values in a single MGET call
-        if matching_keys:
-            values = await self.redis.mget(matching_keys)
-            return [
-                value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
-            ]
-
-        return []
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        """Get all keys in the given range."""
-        matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
-        if not matching_keys:
-            return []
-        return [k.decode("utf-8") for k in matching_keys]
+from llama_stack.core.storage.kvstore.redis.redis import *  # noqa: F401,F403
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .sqlite import SqliteKVStoreImpl  # noqa: F401
+from llama_stack.core.storage.kvstore.sqlite import *  # noqa: F401,F403
@@ -4,17 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from pydantic import BaseModel, Field
-
-from llama_stack_api import json_schema_type
-
-
-@json_schema_type
-class SqliteControlPlaneConfig(BaseModel):
-    db_path: str = Field(
-        description="File path for the sqlite database",
-    )
-    table_name: str = Field(
-        default="llamastack_control_plane",
-        description="Table into which all the keys will be placed",
-    )
+from llama_stack.core.storage.kvstore.sqlite.config import *  # noqa: F401,F403
@@ -4,171 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import os
-from datetime import datetime
-
-import aiosqlite
-
-from llama_stack.log import get_logger
-
-from ..api import KVStore
-from ..config import SqliteKVStoreConfig
-
-logger = get_logger(name=__name__, category="providers::utils")
-
-
-class SqliteKVStoreImpl(KVStore):
-    def __init__(self, config: SqliteKVStoreConfig):
-        self.db_path = config.db_path
-        self.table_name = "kvstore"
-        self._conn: aiosqlite.Connection | None = None
-
-    def __str__(self):
-        return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})"
-
-    def _is_memory_db(self) -> bool:
-        """Check if this is an in-memory database."""
-        return self.db_path == ":memory:" or "mode=memory" in self.db_path
-
-    async def initialize(self):
-        # Skip directory creation for in-memory databases and file: URIs
-        if not self._is_memory_db() and not self.db_path.startswith("file:"):
-            db_dir = os.path.dirname(self.db_path)
-            if db_dir:  # Only create if there's a directory component
-                os.makedirs(db_dir, exist_ok=True)
-
-        # Only use persistent connection for in-memory databases
-        # File-based databases use connection-per-operation to avoid hangs
-        if self._is_memory_db():
-            self._conn = await aiosqlite.connect(self.db_path)
-            await self._conn.execute(
-                f"""
-                CREATE TABLE IF NOT EXISTS {self.table_name} (
-                    key TEXT PRIMARY KEY,
-                    value TEXT,
-                    expiration TIMESTAMP
-                )
-                """
-            )
-            await self._conn.commit()
-        else:
-            # For file-based databases, just create the table
-            async with aiosqlite.connect(self.db_path) as db:
-                await db.execute(
-                    f"""
-                    CREATE TABLE IF NOT EXISTS {self.table_name} (
-                        key TEXT PRIMARY KEY,
-                        value TEXT,
-                        expiration TIMESTAMP
-                    )
-                    """
-                )
-                await db.commit()
-
-    async def shutdown(self):
-        """Close the persistent connection (only for in-memory databases)."""
-        if self._conn:
-            await self._conn.close()
-            self._conn = None
-
-    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
-        if self._conn:
-            # In-memory database with persistent connection
-            await self._conn.execute(
-                f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
-                (key, value, expiration),
-            )
-            await self._conn.commit()
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                await db.execute(
-                    f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
-                    (key, value, expiration),
-                )
-                await db.commit()
-
-    async def get(self, key: str) -> str | None:
-        if self._conn:
-            # In-memory database with persistent connection
-            async with self._conn.execute(
-                f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
-            ) as cursor:
-                row = await cursor.fetchone()
-                if row is None:
-                    return None
-                value, expiration = row
-                if not isinstance(value, str):
-                    logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
-                    return None
-                return value
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                async with db.execute(
-                    f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
-                ) as cursor:
-                    row = await cursor.fetchone()
-                    if row is None:
-                        return None
-                    value, expiration = row
-                    if not isinstance(value, str):
-                        logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
-                        return None
-                    return value
-
-    async def delete(self, key: str) -> None:
-        if self._conn:
-            # In-memory database with persistent connection
-            await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
-            await self._conn.commit()
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
-                await db.commit()
-
-    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
-        if self._conn:
-            # In-memory database with persistent connection
-            async with self._conn.execute(
-                f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
-                (start_key, end_key),
-            ) as cursor:
-                result = []
-                async for row in cursor:
-                    _, value, _ = row
-                    result.append(value)
-                return result
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                async with db.execute(
-                    f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
-                    (start_key, end_key),
-                ) as cursor:
-                    result = []
-                    async for row in cursor:
-                        _, value, _ = row
-                        result.append(value)
-                    return result
-
-    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
-        """Get all keys in the given range."""
-        if self._conn:
-            # In-memory database with persistent connection
-            cursor = await self._conn.execute(
-                f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
-                (start_key, end_key),
-            )
-            rows = await cursor.fetchall()
-            return [row[0] for row in rows]
-        else:
-            # File-based database with connection per operation
-            async with aiosqlite.connect(self.db_path) as db:
-                cursor = await db.execute(
-                    f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
-                    (start_key, end_key),
-                )
-                rows = await cursor.fetchall()
-                return [row[0] for row in rows]
+from llama_stack.core.storage.kvstore.sqlite.sqlite import *  # noqa: F401,F403
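For orientation, a minimal usage sketch of the SQLite kvstore relocated above, built from the methods visible in this hunk; the config import path is an assumption about the post-move layout, and the db_path is illustrative:

    import asyncio

    from llama_stack.core.storage.kvstore.sqlite.sqlite import SqliteKVStoreImpl
    # Assumed post-move home of SqliteKVStoreConfig (it travels with the implementation).
    from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig


    async def demo() -> None:
        store = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path="/tmp/kv.db"))
        await store.initialize()  # creates the table; file-backed DBs use one connection per operation
        await store.set("sessions:abc", "{}")
        assert await store.get("sessions:abc") == "{}"
        # Range scans are inclusive on both ends (key >= start AND key <= end)
        print(await store.keys_in_range("sessions:", "sessions:~"))
        await store.shutdown()


    asyncio.run(demo())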
@@ -17,7 +17,6 @@ from pydantic import TypeAdapter

 from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
     ChunkForDeletion,
     content_from_data_and_mime_type,
@@ -53,6 +52,7 @@ from llama_stack_api import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack_api.internal.kvstore import KVStore

 EMBEDDING_DIMENSION = 768
@@ -18,9 +18,9 @@ from llama_stack_api import (
     Order,
 )

-from ..sqlstore.api import ColumnDefinition, ColumnType
-from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import sqlstore_impl
+from llama_stack.core.storage.sqlstore.api import ColumnDefinition, ColumnType
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl

 logger = get_logger(name=__name__, category="openai_responses")
@@ -3,3 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from llama_stack.core.storage.sqlstore import *  # noqa: F401,F403
@@ -4,137 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Mapping, Sequence
-from enum import Enum
-from typing import Any, Literal, Protocol
-
-from pydantic import BaseModel
-
-from llama_stack_api import PaginatedResponse
-
-
-class ColumnType(Enum):
-    INTEGER = "INTEGER"
-    STRING = "STRING"
-    TEXT = "TEXT"
-    FLOAT = "FLOAT"
-    BOOLEAN = "BOOLEAN"
-    JSON = "JSON"
-    DATETIME = "DATETIME"
-
-
-class ColumnDefinition(BaseModel):
-    type: ColumnType
-    primary_key: bool = False
-    nullable: bool = True
-    default: Any = None
-
-
-class SqlStore(Protocol):
-    """
-    A protocol for a SQL store.
-    """
-
-    async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
-        """
-        Create a table.
-        """
-        pass
-
-    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
-        """
-        Insert a row or batch of rows into a table.
-        """
-        pass
-
-    async def upsert(
-        self,
-        table: str,
-        data: Mapping[str, Any],
-        conflict_columns: list[str],
-        update_columns: list[str] | None = None,
-    ) -> None:
-        """
-        Insert a row and update specified columns when conflicts occur.
-        """
-        pass
-
-    async def fetch_all(
-        self,
-        table: str,
-        where: Mapping[str, Any] | None = None,
-        where_sql: str | None = None,
-        limit: int | None = None,
-        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
-        cursor: tuple[str, str] | None = None,
-    ) -> PaginatedResponse:
-        """
-        Fetch all rows from a table with optional cursor-based pagination.
-
-        :param table: The table name
-        :param where: Simple key-value WHERE conditions
-        :param where_sql: Raw SQL WHERE clause for complex queries
-        :param limit: Maximum number of records to return
-        :param order_by: List of (column, order) tuples for sorting
-        :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
-                       Requires order_by with exactly one column when used
-        :return: PaginatedResult with data and has_more flag
-
-        Note: Cursor pagination only supports single-column ordering for simplicity.
-              Multi-column ordering is allowed without cursor but will raise an error with cursor.
-        """
-        pass
-
-    async def fetch_one(
-        self,
-        table: str,
-        where: Mapping[str, Any] | None = None,
-        where_sql: str | None = None,
-        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
-    ) -> dict[str, Any] | None:
-        """
-        Fetch one row from a table.
-        """
-        pass
-
-    async def update(
-        self,
-        table: str,
-        data: Mapping[str, Any],
-        where: Mapping[str, Any],
-    ) -> None:
-        """
-        Update a row in a table.
-        """
-        pass
-
-    async def delete(
-        self,
-        table: str,
-        where: Mapping[str, Any],
-    ) -> None:
-        """
-        Delete a row from a table.
-        """
-        pass
-
-    async def add_column_if_not_exists(
-        self,
-        table: str,
-        column_name: str,
-        column_type: ColumnType,
-        nullable: bool = True,
-    ) -> None:
-        """
-        Add a column to an existing table if the column doesn't already exist.
-
-        This is useful for table migrations when adding new functionality.
-        If the table doesn't exist, this method should do nothing.
-        If the column already exists, this method should do nothing.
-
-        :param table: Table name
-        :param column_name: Name of the column to add
-        :param column_type: Type of the column to add
-        :param nullable: Whether the column should be nullable (default: True)
-        """
-        pass
+from llama_stack.core.storage.sqlstore.api import *  # noqa: F401,F403
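The deleted protocol above now also lives in llama_stack_api.internal.sqlstore (added later in this diff). A short sketch of a schema built from these types; the table and column names are illustrative:

    from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

    # A schema value may be a bare ColumnType, or a ColumnDefinition when the
    # column needs primary-key or nullability settings.
    schema = {
        "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True, nullable=False),
        "created_at": ColumnType.DATETIME,
        "payload": ColumnType.JSON,
    }

    # Any SqlStore implementation accepts this mapping:
    #     await store.create_table("example_items", schema)
    #     await store.insert("example_items", {"id": "1", "payload": {"k": "v"}})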
@@ -4,333 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Mapping, Sequence
-from typing import Any, Literal
-
-from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
-from llama_stack.core.access_control.conditions import ProtectedResource
-from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
-from llama_stack.core.datatypes import User
-from llama_stack.core.request_headers import get_authenticated_user
-from llama_stack.core.storage.datatypes import StorageBackendType
-from llama_stack.log import get_logger
-
-from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
-
-logger = get_logger(name=__name__, category="providers::utils")
-
-# Hardcoded copy of the default policy that our SQL filtering implements
-# WARNING: If default_policy() changes, this constant must be updated accordingly
-# or SQL filtering will fall back to conservative mode (safe but less performant)
-#
-# This policy represents: "Permit all actions when user is in owners list for ALL attribute categories"
-# The corresponding SQL logic is implemented in _build_default_policy_where_clause():
-# - Public records (no access_attributes) are always accessible
-# - Records with access_attributes require user to match ALL categories that exist in the resource
-# - Missing categories in the resource are treated as "no restriction" (allow)
-# - Within each category, user needs ANY matching value (OR logic)
-# - Between categories, user needs ALL categories to match (AND logic)
-SQL_OPTIMIZED_POLICY = [
-    AccessRule(
-        permit=Scope(actions=list(Action)),
-        when=["user in owners roles", "user in owners teams", "user in owners projects", "user in owners namespaces"],
-    ),
-]
-
-
-def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
-    """Add access control attributes to a data item."""
-    enhanced = dict(item)
-    if current_user:
-        enhanced["owner_principal"] = current_user.principal
-        enhanced["access_attributes"] = current_user.attributes
-    else:
-        # IMPORTANT: Use empty string and null value (not None) to match public access filter
-        # The public access filter in _get_public_access_conditions() expects:
-        # - owner_principal = '' (empty string)
-        # - access_attributes = null (JSON null, which serializes to the string 'null')
-        # Setting them to None (SQL NULL) will cause rows to be filtered out on read.
-        enhanced["owner_principal"] = ""
-        enhanced["access_attributes"] = None  # Pydantic/JSON will serialize this as JSON null
-    return enhanced
-
-
-class SqlRecord(ProtectedResource):
-    def __init__(self, record_id: str, table_name: str, owner: User):
-        self.type = f"sql_record::{table_name}"
-        self.identifier = record_id
-        self.owner = owner
-
-
-class AuthorizedSqlStore:
-    """
-    Authorization layer for SqlStore that provides access control functionality.
-
-    This class composes a base SqlStore and adds authorization methods that handle
-    access control policies, user attribute capture, and SQL filtering optimization.
-    """
-
-    def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
-        """
-        Initialize the authorization layer.
-
-        :param sql_store: Base SqlStore implementation to wrap
-        :param policy: Access control policy to use for authorization
-        """
-        self.sql_store = sql_store
-        self.policy = policy
-        self._detect_database_type()
-        self._validate_sql_optimized_policy()
-
-    def _detect_database_type(self) -> None:
-        """Detect the database type from the underlying SQL store."""
-        if not hasattr(self.sql_store, "config"):
-            raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
-
-        self.database_type = self.sql_store.config.type.value
-        if self.database_type not in [StorageBackendType.SQL_POSTGRES.value, StorageBackendType.SQL_SQLITE.value]:
-            raise ValueError(f"Unsupported database type: {self.database_type}")
-
-    def _validate_sql_optimized_policy(self) -> None:
-        """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
-
-        This ensures that if default_policy() changes, we detect the mismatch and
-        can update our SQL filtering logic accordingly.
-        """
-        actual_default = default_policy()
-
-        if SQL_OPTIMIZED_POLICY != actual_default:
-            logger.warning(
-                f"SQL_OPTIMIZED_POLICY does not match default_policy(). "
-                f"SQL filtering will use conservative mode. "
-                f"Expected: {SQL_OPTIMIZED_POLICY}, Got: {actual_default}",
-            )
-
-    async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
-        """Create a table with built-in access control support."""
-
-        enhanced_schema = dict(schema)
-        if "access_attributes" not in enhanced_schema:
-            enhanced_schema["access_attributes"] = ColumnType.JSON
-        if "owner_principal" not in enhanced_schema:
-            enhanced_schema["owner_principal"] = ColumnType.STRING
-
-        await self.sql_store.create_table(table, enhanced_schema)
-        await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
-        await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
-
-    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
-        """Insert a row or batch of rows with automatic access control attribute capture."""
-        current_user = get_authenticated_user()
-        enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
-        if isinstance(data, Mapping):
-            enhanced_data = _enhance_item_with_access_control(data, current_user)
-        else:
-            enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
-        await self.sql_store.insert(table, enhanced_data)
-
-    async def upsert(
-        self,
-        table: str,
-        data: Mapping[str, Any],
-        conflict_columns: list[str],
-        update_columns: list[str] | None = None,
-    ) -> None:
-        """Upsert a row with automatic access control attribute capture."""
-        current_user = get_authenticated_user()
-        enhanced_data = _enhance_item_with_access_control(data, current_user)
-        await self.sql_store.upsert(
-            table=table,
-            data=enhanced_data,
-            conflict_columns=conflict_columns,
-            update_columns=update_columns,
-        )
-
-    async def fetch_all(
-        self,
-        table: str,
-        where: Mapping[str, Any] | None = None,
-        limit: int | None = None,
-        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
-        cursor: tuple[str, str] | None = None,
-    ) -> PaginatedResponse:
-        """Fetch all rows with automatic access control filtering."""
-        access_where = self._build_access_control_where_clause(self.policy)
-        rows = await self.sql_store.fetch_all(
-            table=table,
-            where=where,
-            where_sql=access_where,
-            limit=limit,
-            order_by=order_by,
-            cursor=cursor,
-        )
-
-        current_user = get_authenticated_user()
-        filtered_rows = []
-
-        for row in rows.data:
-            stored_access_attrs = row.get("access_attributes")
-            stored_owner_principal = row.get("owner_principal") or ""
-
-            record_id = row.get("id", "unknown")
-            sql_record = SqlRecord(
-                str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
-            )
-
-            if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
-                filtered_rows.append(row)
-
-        return PaginatedResponse(
-            data=filtered_rows,
-            has_more=rows.has_more,
-        )
-
-    async def fetch_one(
-        self,
-        table: str,
-        where: Mapping[str, Any] | None = None,
-        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
-    ) -> dict[str, Any] | None:
-        """Fetch one row with automatic access control checking."""
-        results = await self.fetch_all(
-            table=table,
-            where=where,
-            limit=1,
-            order_by=order_by,
-        )
-
-        return results.data[0] if results.data else None
-
-    async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
-        """Update rows with automatic access control attribute capture."""
-        enhanced_data = dict(data)
-
-        current_user = get_authenticated_user()
-        if current_user:
-            enhanced_data["owner_principal"] = current_user.principal
-            enhanced_data["access_attributes"] = current_user.attributes
-        else:
-            # IMPORTANT: Use empty string for owner_principal to match public access filter
-            enhanced_data["owner_principal"] = ""
-            enhanced_data["access_attributes"] = None  # Will serialize as JSON null
-
-        await self.sql_store.update(table, enhanced_data, where)
-
-    async def delete(self, table: str, where: Mapping[str, Any]) -> None:
-        """Delete rows with automatic access control filtering."""
-        await self.sql_store.delete(table, where)
-
-    def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
-        """Build SQL WHERE clause for access control filtering.
-
-        Only applies SQL filtering for the default policy to ensure correctness.
-        For custom policies, uses conservative filtering to avoid blocking legitimate access.
-        """
-        current_user = get_authenticated_user()
-
-        if not policy or policy == SQL_OPTIMIZED_POLICY:
-            return self._build_default_policy_where_clause(current_user)
-        else:
-            return self._build_conservative_where_clause()
-
-    def _json_extract(self, column: str, path: str) -> str:
-        """Extract JSON value (keeping JSON type).
-
-        Args:
-            column: The JSON column name
-            path: The JSON path (e.g., 'roles', 'teams')
-
-        Returns:
-            SQL expression to extract JSON value
-        """
-        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
-            return f"{column}->'{path}'"
-        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
-            return f"JSON_EXTRACT({column}, '$.{path}')"
-        else:
-            raise ValueError(f"Unsupported database type: {self.database_type}")
-
-    def _json_extract_text(self, column: str, path: str) -> str:
-        """Extract JSON value as text.
-
-        Args:
-            column: The JSON column name
-            path: The JSON path (e.g., 'roles', 'teams')
-
-        Returns:
-            SQL expression to extract JSON value as text
-        """
-        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
-            return f"{column}->>'{path}'"
-        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
-            return f"JSON_EXTRACT({column}, '$.{path}')"
-        else:
-            raise ValueError(f"Unsupported database type: {self.database_type}")
-
-    def _get_public_access_conditions(self) -> list[str]:
-        """Get the SQL conditions for public access.
-
-        Public records are those with:
-        - owner_principal = '' (empty string)
-        - access_attributes is either SQL NULL or JSON null
-
-        Note: Different databases serialize None differently:
-        - SQLite: None → JSON null (text = 'null')
-        - Postgres: None → SQL NULL (IS NULL)
-        """
-        conditions = ["owner_principal = ''"]
-        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
-            # Accept both SQL NULL and JSON null for Postgres compatibility
-            # This handles both old rows (SQL NULL) and new rows (JSON null)
-            conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
-        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
-            # SQLite serializes None as JSON null
-            conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
-        else:
-            raise ValueError(f"Unsupported database type: {self.database_type}")
-        return conditions
-
-    def _build_default_policy_where_clause(self, current_user: User | None) -> str:
-        """Build SQL WHERE clause for the default policy.
-
-        Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
-        This means user must match ALL attribute categories that exist in the resource.
-        """
-        base_conditions = self._get_public_access_conditions()
-        user_attr_conditions = []
-
-        if current_user and current_user.attributes:
-            for attr_key, user_values in current_user.attributes.items():
-                if user_values:
-                    value_conditions = []
-                    for value in user_values:
-                        # Check if JSON array contains the value
-                        escaped_value = value.replace("'", "''")
-                        json_text = self._json_extract_text("access_attributes", attr_key)
-                        value_conditions.append(f"({json_text} LIKE '%\"{escaped_value}\"%')")
-
-                    if value_conditions:
-                        # Check if the category is missing (NULL)
-                        category_missing = f"{self._json_extract('access_attributes', attr_key)} IS NULL"
-                        user_matches_category = f"({' OR '.join(value_conditions)})"
-                        user_attr_conditions.append(f"({category_missing} OR {user_matches_category})")
-
-        if user_attr_conditions:
-            all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
-            base_conditions.append(all_requirements_met)
-
-        return f"({' OR '.join(base_conditions)})"
-
-    def _build_conservative_where_clause(self) -> str:
-        """Conservative SQL filtering for custom policies.
-
-        Only filters records we're 100% certain would be denied by any reasonable policy.
-        """
-        current_user = get_authenticated_user()
-
-        if not current_user:
-            # Only allow public records
-            base_conditions = self._get_public_access_conditions()
-            return f"({' OR '.join(base_conditions)})"
-
-        return "1=1"
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import *  # noqa: F401,F403
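A sketch of how the relocated wrapper is used; the base store wiring and table name are illustrative, while the calls match the code above:

    from llama_stack.core.access_control.access_control import default_policy
    from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore


    async def list_visible_rows(base_store) -> list[dict]:
        # base_store must expose .config; the wrapper inspects it to choose the
        # SQLite or Postgres JSON-extraction syntax.
        store = AuthorizedSqlStore(base_store, default_policy())
        # Writes stamp owner_principal/access_attributes from the authenticated user;
        # reads are pre-filtered in SQL for the default policy, then re-checked in Python.
        page = await store.fetch_all(table="example_items", limit=20)
        return page.data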
@@ -3,372 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from collections.abc import Mapping, Sequence
-from typing import Any, Literal, cast
-
-from sqlalchemy import (
-    JSON,
-    Boolean,
-    Column,
-    DateTime,
-    Float,
-    Integer,
-    MetaData,
-    String,
-    Table,
-    Text,
-    event,
-    inspect,
-    select,
-    text,
-)
-from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
-from sqlalchemy.ext.asyncio.engine import AsyncEngine
-from sqlalchemy.sql.elements import ColumnElement
-
-from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
-from llama_stack.log import get_logger
-from llama_stack_api import PaginatedResponse
-
-from .api import ColumnDefinition, ColumnType, SqlStore
-
-logger = get_logger(name=__name__, category="providers::utils")
-
-TYPE_MAPPING: dict[ColumnType, Any] = {
-    ColumnType.INTEGER: Integer,
-    ColumnType.STRING: String,
-    ColumnType.FLOAT: Float,
-    ColumnType.BOOLEAN: Boolean,
-    ColumnType.DATETIME: DateTime,
-    ColumnType.TEXT: Text,
-    ColumnType.JSON: JSON,
-}
-
-
-def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
-    """Return a SQLAlchemy expression for a where condition.
-
-    `value` may be a simple scalar (equality) or a mapping like {">": 123}.
-    The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
-    """
-    if isinstance(value, Mapping):
-        if len(value) != 1:
-            raise ValueError(f"Operator mapping must have a single operator, got: {value}")
-        op, operand = next(iter(value.items()))
-        if op == "==" or op == "=":
-            return cast(ColumnElement[Any], column == operand)
-        if op == ">":
-            return cast(ColumnElement[Any], column > operand)
-        if op == "<":
-            return cast(ColumnElement[Any], column < operand)
-        if op == ">=":
-            return cast(ColumnElement[Any], column >= operand)
-        if op == "<=":
-            return cast(ColumnElement[Any], column <= operand)
-        raise ValueError(f"Unsupported operator '{op}' in where mapping")
-    return cast(ColumnElement[Any], column == value)
-
-
-class SqlAlchemySqlStoreImpl(SqlStore):
-    def __init__(self, config: SqlAlchemySqlStoreConfig):
-        self.config = config
-        self._is_sqlite_backend = "sqlite" in self.config.engine_str
-        self.async_session = async_sessionmaker(self.create_engine())
-        self.metadata = MetaData()
-
-    def create_engine(self) -> AsyncEngine:
-        # Configure connection args for better concurrency support
-        connect_args = {}
-        if self._is_sqlite_backend:
-            # SQLite-specific optimizations for concurrent access
-            # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
-            connect_args["timeout"] = 5.0
-            connect_args["check_same_thread"] = False  # Allow usage across asyncio tasks
-
-        engine = create_async_engine(
-            self.config.engine_str,
-            pool_pre_ping=True,
-            connect_args=connect_args,
-        )
-
-        # Enable WAL mode for SQLite to support concurrent readers and writers
-        if self._is_sqlite_backend:
-
-            @event.listens_for(engine.sync_engine, "connect")
-            def set_sqlite_pragma(dbapi_conn, connection_record):
-                cursor = dbapi_conn.cursor()
-                # Enable Write-Ahead Logging for better concurrency
-                cursor.execute("PRAGMA journal_mode=WAL")
-                # Set busy timeout to 5 seconds (retry instead of immediate failure)
-                # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
-                cursor.execute("PRAGMA busy_timeout=5000")
-                # Use NORMAL synchronous mode for better performance (still safe with WAL)
-                cursor.execute("PRAGMA synchronous=NORMAL")
-                cursor.close()
-
-        return engine
-
-    async def create_table(
-        self,
-        table: str,
-        schema: Mapping[str, ColumnType | ColumnDefinition],
-    ) -> None:
-        if not schema:
-            raise ValueError(f"No columns defined for table '{table}'.")
-
-        sqlalchemy_columns: list[Column] = []
-
-        for col_name, col_props in schema.items():
-            col_type = None
-            is_primary_key = False
-            is_nullable = True
-
-            if isinstance(col_props, ColumnType):
-                col_type = col_props
-            elif isinstance(col_props, ColumnDefinition):
-                col_type = col_props.type
-                is_primary_key = col_props.primary_key
-                is_nullable = col_props.nullable
-
-            sqlalchemy_type = TYPE_MAPPING.get(col_type)
-            if not sqlalchemy_type:
-                raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.")
-
-            sqlalchemy_columns.append(
-                Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable)
-            )
-
-        if table not in self.metadata.tables:
-            sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns)
-        else:
-            sqlalchemy_table = self.metadata.tables[table]
-
-        engine = self.create_engine()
-        async with engine.begin() as conn:
-            await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
-
-    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
-        async with self.async_session() as session:
-            await session.execute(self.metadata.tables[table].insert(), data)
-            await session.commit()
-
-    async def upsert(
-        self,
-        table: str,
-        data: Mapping[str, Any],
-        conflict_columns: list[str],
-        update_columns: list[str] | None = None,
-    ) -> None:
-        table_obj = self.metadata.tables[table]
-        dialect_insert = self._get_dialect_insert(table_obj)
-        insert_stmt = dialect_insert.values(**data)
-
-        if update_columns is None:
-            update_columns = [col for col in data.keys() if col not in conflict_columns]
-
-        update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
-        conflict_cols = [table_obj.c[col] for col in conflict_columns]
-
-        stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
-
-        async with self.async_session() as session:
-            await session.execute(stmt)
-            await session.commit()
-
-    async def fetch_all(
-        self,
-        table: str,
-        where: Mapping[str, Any] | None = None,
-        where_sql: str | None = None,
-        limit: int | None = None,
-        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
-        cursor: tuple[str, str] | None = None,
-    ) -> PaginatedResponse:
-        async with self.async_session() as session:
-            table_obj = self.metadata.tables[table]
-            query = select(table_obj)
-
-            if where:
-                for key, value in where.items():
-                    query = query.where(_build_where_expr(table_obj.c[key], value))
-
-            if where_sql:
-                query = query.where(text(where_sql))
-
-            # Handle cursor-based pagination
-            if cursor:
-                # Validate cursor tuple format
-                if not isinstance(cursor, tuple) or len(cursor) != 2:
-                    raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")
-
-                # Require order_by for cursor pagination
-                if not order_by:
-                    raise ValueError("order_by is required when using cursor pagination")
-
-                # Only support single-column ordering for cursor pagination
-                if len(order_by) != 1:
-                    raise ValueError(
-                        f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
-                    )
-
-                cursor_key_column, cursor_id = cursor
-                order_column, order_direction = order_by[0]
-
-                # Verify cursor_key_column exists
-                if cursor_key_column not in table_obj.c:
-                    raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")
-
-                # Get cursor value for the order column
-                cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
-                cursor_result = await session.execute(cursor_query)
-                cursor_row = cursor_result.fetchone()
-
-                if not cursor_row:
-                    raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")
-
-                cursor_value = cursor_row[0]
-
-                # Apply cursor condition based on sort direction
-                if order_direction == "desc":
-                    query = query.where(table_obj.c[order_column] < cursor_value)
-                else:
-                    query = query.where(table_obj.c[order_column] > cursor_value)
-
-            # Apply ordering
-            if order_by:
-                if not isinstance(order_by, list):
-                    raise ValueError(
-                        f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
-                    )
-                for order in order_by:
-                    if not isinstance(order, tuple):
-                        raise ValueError(
-                            f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
-                        )
-                    name, order_type = order
-                    if name not in table_obj.c:
-                        raise ValueError(f"Column '{name}' not found in table '{table}'")
-                    if order_type == "asc":
-                        query = query.order_by(table_obj.c[name].asc())
-                    elif order_type == "desc":
-                        query = query.order_by(table_obj.c[name].desc())
-                    else:
-                        raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
-
-            # Fetch limit + 1 to determine has_more
-            fetch_limit = limit
-            if limit:
-                fetch_limit = limit + 1
-
-            if fetch_limit:
-                query = query.limit(fetch_limit)
-
-            result = await session.execute(query)
-            # Iterate directly - if no rows, list comprehension yields empty list
-            rows = [dict(row._mapping) for row in result]
-
-            # Always return pagination result
-            has_more = False
-            if limit and len(rows) > limit:
-                has_more = True
-                rows = rows[:limit]
-
-            return PaginatedResponse(data=rows, has_more=has_more)
-
-    async def fetch_one(
-        self,
-        table: str,
-        where: Mapping[str, Any] | None = None,
-        where_sql: str | None = None,
-        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
-    ) -> dict[str, Any] | None:
-        result = await self.fetch_all(table, where, where_sql, limit=1, order_by=order_by)
-        if not result.data:
-            return None
-        return result.data[0]
-
-    async def update(
-        self,
-        table: str,
-        data: Mapping[str, Any],
-        where: Mapping[str, Any],
-    ) -> None:
-        if not where:
-            raise ValueError("where is required for update")
-
-        async with self.async_session() as session:
-            stmt = self.metadata.tables[table].update()
-            for key, value in where.items():
-                stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
-            await session.execute(stmt, data)
-            await session.commit()
-
-    async def delete(self, table: str, where: Mapping[str, Any]) -> None:
-        if not where:
-            raise ValueError("where is required for delete")
-
-        async with self.async_session() as session:
-            stmt = self.metadata.tables[table].delete()
-            for key, value in where.items():
-                stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
-            await session.execute(stmt)
-            await session.commit()
-
-    async def add_column_if_not_exists(
-        self,
-        table: str,
-        column_name: str,
-        column_type: ColumnType,
-        nullable: bool = True,
-    ) -> None:
-        """Add a column to an existing table if the column doesn't already exist."""
-        engine = self.create_engine()
-
-        try:
-            async with engine.begin() as conn:
-
-                def check_column_exists(sync_conn):
-                    inspector = inspect(sync_conn)
-
-                    table_names = inspector.get_table_names()
-                    if table not in table_names:
-                        return False, False  # table doesn't exist, column doesn't exist
-
-                    existing_columns = inspector.get_columns(table)
-                    column_names = [col["name"] for col in existing_columns]
-
-                    return True, column_name in column_names  # table exists, column exists or not
-
-                table_exists, column_exists = await conn.run_sync(check_column_exists)
-                if not table_exists or column_exists:
-                    return
-
-                sqlalchemy_type = TYPE_MAPPING.get(column_type)
-                if not sqlalchemy_type:
-                    raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
-
-                # Create the ALTER TABLE statement
-                # Note: We need to get the dialect-specific type name
-                dialect = engine.dialect
-                type_impl = sqlalchemy_type()
-                compiled_type = type_impl.compile(dialect=dialect)
-
-                nullable_clause = "" if nullable else " NOT NULL"
-                add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
-
-                await conn.execute(add_column_sql)
-        except Exception as e:
-            # If any error occurs during migration, log it but don't fail
-            # The table creation will handle adding the column
-            logger.error(f"Error adding column {column_name} to table {table}: {e}")
-            pass
-
-    def _get_dialect_insert(self, table: Table):
-        if self._is_sqlite_backend:
-            from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-
-            return sqlite_insert(table)
-        else:
-            from sqlalchemy.dialects.postgresql import insert as pg_insert
-
-            return pg_insert(table)
+from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import *  # noqa: F401,F403
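The cursor logic above pairs one order column with a key column; a sketch of paging through a table with it (table and column names are illustrative):

    async def iter_rows(store, table: str, page_size: int = 100):
        cursor = None  # first page: no cursor
        while True:
            page = await store.fetch_all(
                table,
                order_by=[("created_at", "desc")],  # cursor mode requires exactly one order column
                limit=page_size,
                cursor=cursor,
            )
            for row in page.data:
                yield row
            if not page.has_more:
                break
            # Subsequent pages pass (key_column, id_of_last_row_seen).
            cursor = ("id", str(page.data[-1]["id"]))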
@@ -4,85 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from threading import Lock
-from typing import Annotated, cast
-
-from pydantic import Field
-
-from llama_stack.core.storage.datatypes import (
-    PostgresSqlStoreConfig,
-    SqliteSqlStoreConfig,
-    SqlStoreReference,
-    StorageBackendConfig,
-    StorageBackendType,
-)
-
-from .api import SqlStore
-
-sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
-
-_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
-_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
-_SQLSTORE_LOCKS: dict[str, Lock] = {}
-
-
-SqlStoreConfig = Annotated[
-    SqliteSqlStoreConfig | PostgresSqlStoreConfig,
-    Field(discriminator="type"),
-]
-
-
-def get_pip_packages(store_config: dict | SqlStoreConfig) -> list[str]:
-    """Get pip packages for SQL store config, handling both dict and object cases."""
-    if isinstance(store_config, dict):
-        store_type = store_config.get("type")
-        if store_type == StorageBackendType.SQL_SQLITE.value:
-            return SqliteSqlStoreConfig.pip_packages()
-        elif store_type == StorageBackendType.SQL_POSTGRES.value:
-            return PostgresSqlStoreConfig.pip_packages()
-        else:
-            raise ValueError(f"Unknown SQL store type: {store_type}")
-    else:
-        return store_config.pip_packages()
-
-
-def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
-    backend_name = reference.backend
-
-    backend_config = _SQLSTORE_BACKENDS.get(backend_name)
-    if backend_config is None:
-        raise ValueError(
-            f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-        )
-
-    existing = _SQLSTORE_INSTANCES.get(backend_name)
-    if existing:
-        return existing
-
-    lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
-    with lock:
-        existing = _SQLSTORE_INSTANCES.get(backend_name)
-        if existing:
-            return existing
-
-        if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
-            from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
-
-            config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
-            instance = SqlAlchemySqlStoreImpl(config)
-            _SQLSTORE_INSTANCES[backend_name] = instance
-            return instance
-        else:
-            raise ValueError(f"Unknown sqlstore type {backend_config.type}")
-
-
-def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
-    """Register the set of available SQL store backends for reference resolution."""
-    global _SQLSTORE_BACKENDS
-    global _SQLSTORE_INSTANCES
-
-    _SQLSTORE_BACKENDS.clear()
-    _SQLSTORE_INSTANCES.clear()
-    _SQLSTORE_LOCKS.clear()
-    for name, cfg in backends.items():
-        _SQLSTORE_BACKENDS[name] = cfg
+from llama_stack.core.storage.sqlstore.sqlstore import *  # noqa: F401,F403
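A sketch of how the relocated registry is exercised at startup; the backend name and db_path are illustrative, and the SqlStoreReference fields beyond `backend` are assumptions:

    from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
    from llama_stack.core.storage.sqlstore.sqlstore import (
        register_sqlstore_backends,
        sqlstore_impl,
    )

    # Registration replaces the whole backend table and clears cached instances.
    register_sqlstore_backends({"default": SqliteSqlStoreConfig(db_path="/tmp/stack.db")})

    # References resolve by backend name; repeated calls return the cached instance.
    store = sqlstore_impl(SqlStoreReference(backend="default", table_name="example_items"))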
3 src/llama_stack_api/internal/__init__.py Normal file
@@ -0,0 +1,3 @@
+# Internal subpackage for shared interfaces that are not part of the public API.
+
+__all__: list[str] = []
26 src/llama_stack_api/internal/kvstore.py Normal file
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+from typing import Protocol
+
+
+class KVStore(Protocol):
+    """Protocol for simple key/value storage backends."""
+
+    # TODO: make the value type bytes instead of str
+    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ...
+
+    async def get(self, key: str) -> str | None: ...
+
+    async def delete(self, key: str) -> None: ...
+
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
+
+
+__all__ = ["KVStore"]
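Because the new KVStore is a typing.Protocol, conformance is structural; a minimal in-memory sketch (useful for tests, not part of this commit):

    from datetime import datetime

    from llama_stack_api.internal.kvstore import KVStore


    class InMemoryKVStore:
        """Satisfies KVStore structurally; no inheritance required."""

        def __init__(self) -> None:
            self._data: dict[str, str] = {}

        async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
            self._data[key] = value  # expiration is ignored in this sketch

        async def get(self, key: str) -> str | None:
            return self._data.get(key)

        async def delete(self, key: str) -> None:
            self._data.pop(key, None)

        async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
            return [v for k, v in sorted(self._data.items()) if start_key <= k <= end_key]

        async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
            return [k for k in sorted(self._data) if start_key <= k <= end_key]


    store: KVStore = InMemoryKVStore()  # type-checks against the protocol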
79 src/llama_stack_api/internal/sqlstore.py Normal file
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from collections.abc import Mapping, Sequence
+from enum import Enum
+from typing import Any, Literal, Protocol
+
+from pydantic import BaseModel
+
+from llama_stack_api import PaginatedResponse
+
+
+class ColumnType(Enum):
+    INTEGER = "INTEGER"
+    STRING = "STRING"
+    TEXT = "TEXT"
+    FLOAT = "FLOAT"
+    BOOLEAN = "BOOLEAN"
+    JSON = "JSON"
+    DATETIME = "DATETIME"
+
+
+class ColumnDefinition(BaseModel):
+    type: ColumnType
+    primary_key: bool = False
+    nullable: bool = True
+    default: Any = None
+
+
+class SqlStore(Protocol):
+    """Protocol for common SQL-store functionality."""
+
+    async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: ...
+
+    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: ...
+
+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None: ...
+
+    async def fetch_all(
+        self,
+        table: str,
+        where: Mapping[str, Any] | None = None,
+        where_sql: str | None = None,
+        limit: int | None = None,
+        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
+        cursor: tuple[str, str] | None = None,
+    ) -> PaginatedResponse: ...
+
+    async def fetch_one(
+        self,
+        table: str,
+        where: Mapping[str, Any] | None = None,
+        where_sql: str | None = None,
+        order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
+    ) -> dict[str, Any] | None: ...
+
+    async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None: ...
+
+    async def delete(self, table: str, where: Mapping[str, Any]) -> None: ...
+
+    async def add_column_if_not_exists(
+        self,
+        table: str,
+        column_name: str,
+        column_type: ColumnType,
+        nullable: bool = True,
+    ) -> None: ...
+
+
+__all__ = ["ColumnDefinition", "ColumnType", "SqlStore"]
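The same structural-typing idea applies to SqlStore: shared helpers can be typed against the protocol and accept any backend. A small sketch (the helper is illustrative):

    from llama_stack_api.internal.sqlstore import SqlStore


    async def table_has_rows(store: SqlStore, table: str) -> bool:
        # Works with SqlAlchemySqlStoreImpl or any other conforming implementation.
        page = await store.fetch_all(table, limit=1)
        return bool(page.data)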
@@ -175,7 +175,7 @@ def test_expires_after_requests(openai_client):


 @pytest.mark.xfail(reason="User isolation broken for current providers, must be fixed.")
-@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
 def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack_client):
     """Test that users can only access their own files."""
     from llama_stack_client import NotFoundError
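The mechanics behind this one-line change: unittest.mock.patch replaces a name in the module where it is looked up, so moving authorized_sqlstore under llama_stack.core.storage means every patch target has to move with it. A minimal sketch of the pattern these tests rely on:

from unittest.mock import patch

from llama_stack.core.datatypes import User

# Patch the symbol inside the module that consumes it; patching the original
# definition site would leave AuthorizedSqlStore reading the real function.
with patch(
    "llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user"
) as mock_get_user:
    mock_get_user.return_value = User("user-a", {"roles": ["team-a"]})
    # ... exercise code that consults the authenticated user here ...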
@@ -275,7 +275,7 @@ def test_files_authentication_isolation(mock_get_authenticated_user, llama_stack_client):
         raise e


-@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
 def test_files_authentication_shared_attributes(
     mock_get_authenticated_user, llama_stack_client, provider_type_is_openai
 ):
@@ -335,7 +335,7 @@ def test_files_authentication_shared_attributes(
         raise e


-@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
 def test_files_authentication_anonymous_access(
     mock_get_authenticated_user, llama_stack_client, provider_type_is_openai
 ):
@@ -13,14 +13,14 @@ import pytest
 from llama_stack.core.access_control.access_control import default_policy
 from llama_stack.core.datatypes import User
 from llama_stack.core.storage.datatypes import SqlStoreReference
-from llama_stack.providers.utils.sqlstore.api import ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import (
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
+from llama_stack.core.storage.sqlstore.sqlstore import (
     PostgresSqlStoreConfig,
     SqliteSqlStoreConfig,
     register_sqlstore_backends,
     sqlstore_impl,
 )
+from llama_stack_api.internal.sqlstore import ColumnType


 def get_postgres_config():
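For context on what these rewritten imports set up: tests first register named backends, then resolve a store through a reference. A minimal sketch, assuming register_sqlstore_backends accepts a mapping of backend name to config and that SqlStoreReference names a registered backend plus a table; the field names (db_path, backend, table_name) are assumptions for illustration, not taken from this diff.

from llama_stack.core.storage.datatypes import SqlStoreReference
from llama_stack.core.storage.sqlstore.sqlstore import (
    SqliteSqlStoreConfig,
    register_sqlstore_backends,
    sqlstore_impl,
)

# Register a named SQLite backend, then resolve a store bound to one table.
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path="/tmp/test.db")})
store = sqlstore_impl(SqlStoreReference(backend="sql_default", table_name="documents"))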
@@ -96,7 +96,7 @@ async def cleanup_records(sql_store, table_name, record_ids):


 @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
-@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
     """Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
     backend_name = request.node.callspec.id
@@ -190,7 +190,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):


 @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
-@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):
     """Test that 'user is owner' policies work correctly with record ownership"""
     from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
@@ -23,7 +23,7 @@ from llama_stack.core.storage.datatypes import (
     SqlStoreReference,
     StorageConfig,
 )
-from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
 from llama_stack_api import OpenAIResponseInputMessageContentText, OpenAIResponseMessage


@@ -13,7 +13,7 @@ from llama_stack.providers.inline.files.localfs import (
     LocalfsFilesImpl,
     LocalfsFilesImplConfig,
 )
-from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
 from llama_stack_api import OpenAIFilePurpose, Order, ResourceNotFoundError


@@ -7,8 +7,8 @@
 import pytest

 from llama_stack.core.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
+from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
+from llama_stack.core.storage.kvstore.sqlite import SqliteKVStoreImpl


 @pytest.fixture(scope="function")
@@ -18,7 +18,7 @@ from llama_stack.core.storage.datatypes import (
     SqlStoreReference,
     StorageConfig,
 )
-from llama_stack.providers.utils.kvstore import register_kvstore_backends
+from llama_stack.core.storage.kvstore import register_kvstore_backends


 @pytest.fixture
@@ -24,7 +24,7 @@ from llama_stack.providers.utils.responses.responses_store import (
     ResponsesStore,
     _OpenAIResponseObjectWithInputAndMessages,
 )
-from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
 from llama_stack_api.agents import Order
 from llama_stack_api.inference import (
     OpenAIAssistantMessageParam,
@@ -15,7 +15,7 @@ import pytest
 from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
 from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
 from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
-from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
+from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends


 @pytest.fixture
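The KV side mirrors the SQL side: register named backends once, then resolve stores by reference. A sketch under the same caveats as above; field names such as db_path and namespace, and the assumption that kvstore_impl is awaitable, come from its call sites rather than from this hunk.

from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends


async def example() -> None:
    register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db")})
    # kvstore_impl resolves the reference against the registered backend.
    kvstore = await kvstore_impl(KVStoreReference(backend="kv_default", namespace="batches"))
    await kvstore.set("batch:123", '{"status": "queued"}')
    assert await kvstore.get("batch:123") is not None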
@@ -10,7 +10,7 @@ from moto import mock_aws

 from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
 from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
-from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends


 class MockUploadFile:
@@ -18,11 +18,11 @@ async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
     user_a = User("user-a", {"roles": ["team-a"]})
     user_b = User("user-b", {"roles": ["team-b"]})

-    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+    with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
         mock_get_user.return_value = user_a
         uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)

-    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+    with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
         mock_get_user.return_value = user_b
         listed = await s3_provider.openai_list_files()
         assert all(f.id != uploaded.id for f in listed.data)
@@ -41,11 +41,11 @@ async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
     user_a = User("user-a", {"roles": ["team-a"]})
     user_b = User("user-b", {"roles": ["team-b"]})

-    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+    with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
         mock_get_user.return_value = user_a
         uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)

-    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+    with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
         mock_get_user.return_value = user_b
         with pytest.raises(ResourceNotFoundError):
             await op(s3_provider, uploaded.id)
@@ -56,11 +56,11 @@ async def test_shared_role_allows_listing(s3_provider, sample_text_file):
     user_a = User("user-a", {"roles": ["shared-role"]})
     user_b = User("user-b", {"roles": ["shared-role"]})

-    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+    with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
         mock_get_user.return_value = user_a
         uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)

-    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+    with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
         mock_get_user.return_value = user_b
         listed = await s3_provider.openai_list_files()
         assert any(f.id == uploaded.id for f in listed.data)
@@ -79,10 +79,10 @@ async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
     user_x = User("user-x", {"roles": ["shared-role"]})
     user_y = User("user-y", {"roles": ["shared-role"]})

-    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+    with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
         mock_get_user.return_value = user_x
         uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)

-    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+    with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
         mock_get_user.return_value = user_y
         await op(s3_provider, uploaded.id)
@@ -17,7 +17,7 @@ from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
 from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
-from llama_stack.providers.utils.kvstore import register_kvstore_backends
+from llama_stack.core.storage.kvstore import register_kvstore_backends
 from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore

 EMBEDDING_DIMENSION = 768
@@ -279,7 +279,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
     ) as mock_check_version:
         mock_check_version.return_value = "0.5.1"

-        with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
+        with patch("llama_stack.core.storage.kvstore.kvstore_impl") as mock_kvstore_impl:
             mock_kvstore = AsyncMock()
             mock_kvstore_impl.return_value = mock_kvstore

@@ -14,7 +14,7 @@ from llama_stack.core.store.registry import (
     CachedDiskDistributionRegistry,
     DiskDistributionRegistry,
 )
-from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
+from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends
 from llama_stack_api import Model, VectorStore


@@ -15,7 +15,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
 from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
 from llama_stack.core.server.quota import QuotaMiddleware
 from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
-from llama_stack.providers.utils.kvstore import register_kvstore_backends
+from llama_stack.core.storage.kvstore import register_kvstore_backends


 @pytest.fixture
@@ -24,8 +24,8 @@ from llama_stack.core.storage.datatypes import (
     SqlStoreReference,
     StorageConfig,
 )
-from llama_stack.providers.utils.kvstore import register_kvstore_backends
-from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+from llama_stack.core.storage.kvstore import register_kvstore_backends
+from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
 from llama_stack_api import Inference, InlineProviderSpec, ProviderSpec


@@ -10,7 +10,7 @@ import pytest

 from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
 from llama_stack_api import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
@@ -5,8 +5,8 @@
 # the root directory of this source tree.


-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
+from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
+from llama_stack.core.storage.kvstore.sqlite.sqlite import SqliteKVStoreImpl


 async def test_memory_kvstore_persistence_behavior():
@@ -12,7 +12,7 @@ import pytest

 from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
-from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
+from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
 from llama_stack_api import OpenAIMessageParam, OpenAIResponseInput, OpenAIResponseObject, OpenAIUserMessageParam, Order


@@ -9,9 +9,9 @@ from tempfile import TemporaryDirectory

 import pytest

-from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
-from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
+from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType


 async def test_sqlite_sqlstore():
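As a companion sketch for these rewritten unit-test imports: constructing the SQLAlchemy-backed store directly against a throwaway SQLite file. The db_path field name is an assumption, as is fetch_one returning the inserted row as a plain dict.

from tempfile import TemporaryDirectory

from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType


async def sqlite_store_smoke_test() -> None:
    with TemporaryDirectory() as tmp_dir:
        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{tmp_dir}/test.db"))
        await store.create_table(
            "kv",
            {
                "key": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "value": ColumnType.STRING,
            },
        )
        await store.insert("kv", {"key": "a", "value": "1"})
        row = await store.fetch_one("kv", where={"key": "a"})
        assert row is not None and row["value"] == "1"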
@@ -10,13 +10,13 @@ from unittest.mock import patch
 from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
 from llama_stack.core.access_control.datatypes import Action
 from llama_stack.core.datatypes import User
-from llama_stack.providers.utils.sqlstore.api import ColumnType
-from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
-from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
-from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
+from llama_stack.core.storage.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
+from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
+from llama_stack_api.internal.sqlstore import ColumnType


-@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
     """Test that fetch_all works correctly with where_sql for access control"""
     with TemporaryDirectory() as tmp_dir:
@@ -78,7 +78,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
         assert row["title"] == "User Document"


-@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_sql_policy_consistency(mock_get_authenticated_user):
     """Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
     with TemporaryDirectory() as tmp_dir:
@@ -164,7 +164,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
         )


-@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
+@patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user")
 async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
     """Test that user attributes are properly captured during insert"""
     with TemporaryDirectory() as tmp_dir: