From fcf649b97a8bb99f52097b558acfe2d0285f4ef3 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 12 Nov 2025 12:14:26 -0800 Subject: [PATCH] feat(storage): share sql/kv instances and add upsert support (#4140) A few changes to the storage layer to ensure we reduce unnecessary contention arising out of our design choices (and letting the database layer do its correct thing): - SQL stores now share a single `SqlAlchemySqlStoreImpl` per backend, and `kvstore_impl` caches instances per `(backend, namespace)`. This avoids spawning multiple SQLite connections for the same file, reducing lock contention and aligning the cache story for all backends. - Added an async upsert API (with SQLite/Postgres dialect inserts) and routed it through `AuthorizedSqlStore`, then switched conversations and responses to call it. Using native `ON CONFLICT DO UPDATE` eliminates the insert-then-update retry window that previously caused long WAL lock retries. ### Test Plan Existing tests, added a unit test for `upsert()` --- .../core/conversations/conversations.py | 15 ++--- .../providers/utils/kvstore/kvstore.py | 57 +++++++++++++------ .../utils/responses/responses_store.py | 19 ++----- .../providers/utils/sqlstore/api.py | 12 ++++ .../utils/sqlstore/authorized_sqlstore.py | 17 ++++++ .../utils/sqlstore/sqlalchemy_sqlstore.py | 39 ++++++++++++- .../providers/utils/sqlstore/sqlstore.py | 30 ++++++++-- tests/unit/utils/sqlstore/test_sqlstore.py | 34 ++++++++++- 8 files changed, 172 insertions(+), 51 deletions(-) diff --git a/src/llama_stack/core/conversations/conversations.py b/src/llama_stack/core/conversations/conversations.py index 951de5e9d..f83834522 100644 --- a/src/llama_stack/core/conversations/conversations.py +++ b/src/llama_stack/core/conversations/conversations.py @@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations): "item_data": item_dict, } - # TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update - try: - await self.sql_store.insert(table="conversation_items", data=item_record) - except Exception: - # If insert fails due to ID conflict, update existing record - await self.sql_store.update( - table="conversation_items", - data={"created_at": created_at, "item_data": item_dict}, - where={"id": item_id}, - ) + await self.sql_store.upsert( + table="conversation_items", + data=item_record, + conflict_columns=["id"], + ) created_items.append(item_dict) diff --git a/src/llama_stack/providers/utils/kvstore/kvstore.py b/src/llama_stack/providers/utils/kvstore/kvstore.py index eee51e5d9..5b8d77102 100644 --- a/src/llama_stack/providers/utils/kvstore/kvstore.py +++ b/src/llama_stack/providers/utils/kvstore/kvstore.py @@ -11,6 +11,9 @@ from __future__ import annotations +import asyncio +from collections import defaultdict + from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType from .api import KVStore @@ -53,45 +56,63 @@ class InmemoryKVStoreImpl(KVStore): _KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {} +_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {} +_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock) def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None: """Register the set of available KV store backends for reference resolution.""" global _KVSTORE_BACKENDS + global _KVSTORE_INSTANCES + global _KVSTORE_LOCKS _KVSTORE_BACKENDS.clear() + _KVSTORE_INSTANCES.clear() + _KVSTORE_LOCKS.clear() for name, cfg in backends.items(): _KVSTORE_BACKENDS[name] = cfg async def kvstore_impl(reference: KVStoreReference) -> KVStore: backend_name = reference.backend + cache_key = (backend_name, reference.namespace) + + existing = _KVSTORE_INSTANCES.get(cache_key) + if existing: + return existing backend_config = _KVSTORE_BACKENDS.get(backend_name) if backend_config is None: raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}") - config = backend_config.model_copy() - config.namespace = reference.namespace + lock = _KVSTORE_LOCKS[cache_key] + async with lock: + existing = _KVSTORE_INSTANCES.get(cache_key) + if existing: + return existing - if config.type == StorageBackendType.KV_REDIS.value: - from .redis import RedisKVStoreImpl + config = backend_config.model_copy() + config.namespace = reference.namespace - impl = RedisKVStoreImpl(config) - elif config.type == StorageBackendType.KV_SQLITE.value: - from .sqlite import SqliteKVStoreImpl + if config.type == StorageBackendType.KV_REDIS.value: + from .redis import RedisKVStoreImpl - impl = SqliteKVStoreImpl(config) - elif config.type == StorageBackendType.KV_POSTGRES.value: - from .postgres import PostgresKVStoreImpl + impl = RedisKVStoreImpl(config) + elif config.type == StorageBackendType.KV_SQLITE.value: + from .sqlite import SqliteKVStoreImpl - impl = PostgresKVStoreImpl(config) - elif config.type == StorageBackendType.KV_MONGODB.value: - from .mongodb import MongoDBKVStoreImpl + impl = SqliteKVStoreImpl(config) + elif config.type == StorageBackendType.KV_POSTGRES.value: + from .postgres import PostgresKVStoreImpl - impl = MongoDBKVStoreImpl(config) - else: - raise ValueError(f"Unknown kvstore type {config.type}") + impl = PostgresKVStoreImpl(config) + elif config.type == StorageBackendType.KV_MONGODB.value: + from .mongodb import MongoDBKVStoreImpl - await impl.initialize() - return impl + impl = MongoDBKVStoreImpl(config) + else: + raise ValueError(f"Unknown kvstore type {config.type}") + + await impl.initialize() + _KVSTORE_INSTANCES[cache_key] = impl + return impl diff --git a/src/llama_stack/providers/utils/responses/responses_store.py b/src/llama_stack/providers/utils/responses/responses_store.py index f5024a9ed..fdca8ddee 100644 --- a/src/llama_stack/providers/utils/responses/responses_store.py +++ b/src/llama_stack/providers/utils/responses/responses_store.py @@ -252,19 +252,12 @@ class ResponsesStore: # Serialize messages to dict format for JSON storage messages_data = [msg.model_dump() for msg in messages] - # Upsert: try insert first, update if exists - try: - await self.sql_store.insert( - table="conversation_messages", - data={"conversation_id": conversation_id, "messages": messages_data}, - ) - except Exception: - # If insert fails due to ID conflict, update existing record - await self.sql_store.update( - table="conversation_messages", - data={"messages": messages_data}, - where={"conversation_id": conversation_id}, - ) + await self.sql_store.upsert( + table="conversation_messages", + data={"conversation_id": conversation_id, "messages": messages_data}, + conflict_columns=["conversation_id"], + update_columns=["messages"], + ) logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}") diff --git a/src/llama_stack/providers/utils/sqlstore/api.py b/src/llama_stack/providers/utils/sqlstore/api.py index a61fd1090..bcd224234 100644 --- a/src/llama_stack/providers/utils/sqlstore/api.py +++ b/src/llama_stack/providers/utils/sqlstore/api.py @@ -47,6 +47,18 @@ class SqlStore(Protocol): """ pass + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: + """ + Insert a row and update specified columns when conflicts occur. + """ + pass + async def fetch_all( self, table: str, diff --git a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index eb2d9a491..ba95dd120 100644 --- a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -129,6 +129,23 @@ class AuthorizedSqlStore: enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data] await self.sql_store.insert(table, enhanced_data) + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: + """Upsert a row with automatic access control attribute capture.""" + current_user = get_authenticated_user() + enhanced_data = _enhance_item_with_access_control(data, current_user) + await self.sql_store.upsert( + table=table, + data=enhanced_data, + conflict_columns=conflict_columns, + update_columns=update_columns, + ) + async def fetch_all( self, table: str, diff --git a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 356f49ed1..cfc3131f4 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement: class SqlAlchemySqlStoreImpl(SqlStore): def __init__(self, config: SqlAlchemySqlStoreConfig): self.config = config + self._is_sqlite_backend = "sqlite" in self.config.engine_str self.async_session = async_sessionmaker(self.create_engine()) self.metadata = MetaData() def create_engine(self) -> AsyncEngine: # Configure connection args for better concurrency support connect_args = {} - if "sqlite" in self.config.engine_str: + if self._is_sqlite_backend: # SQLite-specific optimizations for concurrent access # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases connect_args["timeout"] = 5.0 @@ -91,7 +92,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): ) # Enable WAL mode for SQLite to support concurrent readers and writers - if "sqlite" in self.config.engine_str: + if self._is_sqlite_backend: @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma(dbapi_conn, connection_record): @@ -151,6 +152,29 @@ class SqlAlchemySqlStoreImpl(SqlStore): await session.execute(self.metadata.tables[table].insert(), data) await session.commit() + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: + table_obj = self.metadata.tables[table] + dialect_insert = self._get_dialect_insert(table_obj) + insert_stmt = dialect_insert.values(**data) + + if update_columns is None: + update_columns = [col for col in data.keys() if col not in conflict_columns] + + update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns} + conflict_cols = [table_obj.c[col] for col in conflict_columns] + + stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping) + + async with self.async_session() as session: + await session.execute(stmt) + await session.commit() + async def fetch_all( self, table: str, @@ -333,9 +357,18 @@ class SqlAlchemySqlStoreImpl(SqlStore): add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") await conn.execute(add_column_sql) - except Exception as e: # If any error occurs during migration, log it but don't fail # The table creation will handle adding the column logger.error(f"Error adding column {column_name} to table {table}: {e}") pass + + def _get_dialect_insert(self, table: Table): + if self._is_sqlite_backend: + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + + return sqlite_insert(table) + else: + from sqlalchemy.dialects.postgresql import insert as pg_insert + + return pg_insert(table) diff --git a/src/llama_stack/providers/utils/sqlstore/sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlstore.py index 31801c4ca..9409b7d00 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from threading import Lock from typing import Annotated, cast from pydantic import Field @@ -21,6 +22,8 @@ from .api import SqlStore sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"] _SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {} +_SQLSTORE_INSTANCES: dict[str, SqlStore] = {} +_SQLSTORE_LOCKS: dict[str, Lock] = {} SqlStoreConfig = Annotated[ @@ -52,19 +55,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore: f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}" ) - if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig): - from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl + existing = _SQLSTORE_INSTANCES.get(backend_name) + if existing: + return existing - config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy() - return SqlAlchemySqlStoreImpl(config) - else: - raise ValueError(f"Unknown sqlstore type {backend_config.type}") + lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock()) + with lock: + existing = _SQLSTORE_INSTANCES.get(backend_name) + if existing: + return existing + + if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig): + from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl + + config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy() + instance = SqlAlchemySqlStoreImpl(config) + _SQLSTORE_INSTANCES[backend_name] = instance + return instance + else: + raise ValueError(f"Unknown sqlstore type {backend_config.type}") def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None: """Register the set of available SQL store backends for reference resolution.""" global _SQLSTORE_BACKENDS + global _SQLSTORE_INSTANCES _SQLSTORE_BACKENDS.clear() + _SQLSTORE_INSTANCES.clear() + _SQLSTORE_LOCKS.clear() for name, cfg in backends.items(): _SQLSTORE_BACKENDS[name] = cfg diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index 00669b698..d7ba0dc89 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory import pytest -from llama_stack.providers.utils.sqlstore.api import ColumnType +from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig @@ -65,6 +65,38 @@ async def test_sqlite_sqlstore(): assert result.has_more is False +async def test_sqlstore_upsert_support(): + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/upsert.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + await store.create_table( + "items", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "value": ColumnType.STRING, + "updated_at": ColumnType.INTEGER, + }, + ) + + await store.upsert( + table="items", + data={"id": "item_1", "value": "first", "updated_at": 1}, + conflict_columns=["id"], + ) + row = await store.fetch_one("items", {"id": "item_1"}) + assert row == {"id": "item_1", "value": "first", "updated_at": 1} + + await store.upsert( + table="items", + data={"id": "item_1", "value": "second", "updated_at": 2}, + conflict_columns=["id"], + update_columns=["value", "updated_at"], + ) + row = await store.fetch_one("items", {"id": "item_1"}) + assert row == {"id": "item_1", "value": "second", "updated_at": 2} + + async def test_sqlstore_pagination_basic(): """Test basic pagination functionality at the SQL store level.""" with TemporaryDirectory() as tmp_dir: