From c2c5ab622c8a83eba7ce8e3246444a9e95f278ad Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 12 Nov 2025 11:27:35 -0800 Subject: [PATCH] feat(storage): share sql/kv instances and add upsert support Sql stores now share a single SqlAlchemySqlStoreImpl per backend, and kvstore_impl caches instances per (backend, namespace). This avoids spawning multiple SQLite connections for the same file, reducing lock contention and aligning the cache story for all backends. Added an async upsert API (with SQLite/Postgres dialect inserts) and routed it through AuthorizedSqlStore, then switched conversations and responses to call it. Using native ON CONFLICT DO UPDATE eliminates the insert-then-update retry window that previously caused long WAL lock retries. Introduced an opt-in conversation stress test that mirrors the recorded prompts from test_conversation_multi_turn_and_streaming while fanning them out across many threads. This gives us a fast local way to hammer the conversations/responses sync path when investigating lockups. --- .../core/conversations/conversations.py | 15 +- .../providers/utils/kvstore/kvstore.py | 57 +++++--- .../utils/responses/responses_store.py | 19 +-- .../providers/utils/sqlstore/api.py | 12 ++ .../utils/sqlstore/authorized_sqlstore.py | 17 +++ .../utils/sqlstore/sqlalchemy_sqlstore.py | 39 +++++- .../providers/utils/sqlstore/sqlstore.py | 30 +++- .../responses/test_conversation_stress.py | 128 ++++++++++++++++++ tests/unit/utils/sqlstore/test_sqlstore.py | 34 ++++- 9 files changed, 300 insertions(+), 51 deletions(-) create mode 100644 tests/integration/responses/test_conversation_stress.py diff --git a/src/llama_stack/core/conversations/conversations.py b/src/llama_stack/core/conversations/conversations.py index 951de5e9d..f83834522 100644 --- a/src/llama_stack/core/conversations/conversations.py +++ b/src/llama_stack/core/conversations/conversations.py @@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations): "item_data": item_dict, } - # TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update - try: - await self.sql_store.insert(table="conversation_items", data=item_record) - except Exception: - # If insert fails due to ID conflict, update existing record - await self.sql_store.update( - table="conversation_items", - data={"created_at": created_at, "item_data": item_dict}, - where={"id": item_id}, - ) + await self.sql_store.upsert( + table="conversation_items", + data=item_record, + conflict_columns=["id"], + ) created_items.append(item_dict) diff --git a/src/llama_stack/providers/utils/kvstore/kvstore.py b/src/llama_stack/providers/utils/kvstore/kvstore.py index eee51e5d9..5b8d77102 100644 --- a/src/llama_stack/providers/utils/kvstore/kvstore.py +++ b/src/llama_stack/providers/utils/kvstore/kvstore.py @@ -11,6 +11,9 @@ from __future__ import annotations +import asyncio +from collections import defaultdict + from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType from .api import KVStore @@ -53,45 +56,63 @@ class InmemoryKVStoreImpl(KVStore): _KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {} +_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {} +_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock) def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None: """Register the set of available KV store backends for reference resolution.""" global _KVSTORE_BACKENDS + global _KVSTORE_INSTANCES + global _KVSTORE_LOCKS _KVSTORE_BACKENDS.clear() + _KVSTORE_INSTANCES.clear() + _KVSTORE_LOCKS.clear() for name, cfg in backends.items(): _KVSTORE_BACKENDS[name] = cfg async def kvstore_impl(reference: KVStoreReference) -> KVStore: backend_name = reference.backend + cache_key = (backend_name, reference.namespace) + + existing = _KVSTORE_INSTANCES.get(cache_key) + if existing: + return existing backend_config = _KVSTORE_BACKENDS.get(backend_name) if backend_config is None: raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}") - config = backend_config.model_copy() - config.namespace = reference.namespace + lock = _KVSTORE_LOCKS[cache_key] + async with lock: + existing = _KVSTORE_INSTANCES.get(cache_key) + if existing: + return existing - if config.type == StorageBackendType.KV_REDIS.value: - from .redis import RedisKVStoreImpl + config = backend_config.model_copy() + config.namespace = reference.namespace - impl = RedisKVStoreImpl(config) - elif config.type == StorageBackendType.KV_SQLITE.value: - from .sqlite import SqliteKVStoreImpl + if config.type == StorageBackendType.KV_REDIS.value: + from .redis import RedisKVStoreImpl - impl = SqliteKVStoreImpl(config) - elif config.type == StorageBackendType.KV_POSTGRES.value: - from .postgres import PostgresKVStoreImpl + impl = RedisKVStoreImpl(config) + elif config.type == StorageBackendType.KV_SQLITE.value: + from .sqlite import SqliteKVStoreImpl - impl = PostgresKVStoreImpl(config) - elif config.type == StorageBackendType.KV_MONGODB.value: - from .mongodb import MongoDBKVStoreImpl + impl = SqliteKVStoreImpl(config) + elif config.type == StorageBackendType.KV_POSTGRES.value: + from .postgres import PostgresKVStoreImpl - impl = MongoDBKVStoreImpl(config) - else: - raise ValueError(f"Unknown kvstore type {config.type}") + impl = PostgresKVStoreImpl(config) + elif config.type == StorageBackendType.KV_MONGODB.value: + from .mongodb import MongoDBKVStoreImpl - await impl.initialize() - return impl + impl = MongoDBKVStoreImpl(config) + else: + raise ValueError(f"Unknown kvstore type {config.type}") + + await impl.initialize() + _KVSTORE_INSTANCES[cache_key] = impl + return impl diff --git a/src/llama_stack/providers/utils/responses/responses_store.py b/src/llama_stack/providers/utils/responses/responses_store.py index f5024a9ed..fdca8ddee 100644 --- a/src/llama_stack/providers/utils/responses/responses_store.py +++ b/src/llama_stack/providers/utils/responses/responses_store.py @@ -252,19 +252,12 @@ class ResponsesStore: # Serialize messages to dict format for JSON storage messages_data = [msg.model_dump() for msg in messages] - # Upsert: try insert first, update if exists - try: - await self.sql_store.insert( - table="conversation_messages", - data={"conversation_id": conversation_id, "messages": messages_data}, - ) - except Exception: - # If insert fails due to ID conflict, update existing record - await self.sql_store.update( - table="conversation_messages", - data={"messages": messages_data}, - where={"conversation_id": conversation_id}, - ) + await self.sql_store.upsert( + table="conversation_messages", + data={"conversation_id": conversation_id, "messages": messages_data}, + conflict_columns=["conversation_id"], + update_columns=["messages"], + ) logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}") diff --git a/src/llama_stack/providers/utils/sqlstore/api.py b/src/llama_stack/providers/utils/sqlstore/api.py index a61fd1090..bcd224234 100644 --- a/src/llama_stack/providers/utils/sqlstore/api.py +++ b/src/llama_stack/providers/utils/sqlstore/api.py @@ -47,6 +47,18 @@ class SqlStore(Protocol): """ pass + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: + """ + Insert a row and update specified columns when conflicts occur. + """ + pass + async def fetch_all( self, table: str, diff --git a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index eb2d9a491..ba95dd120 100644 --- a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -129,6 +129,23 @@ class AuthorizedSqlStore: enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data] await self.sql_store.insert(table, enhanced_data) + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: + """Upsert a row with automatic access control attribute capture.""" + current_user = get_authenticated_user() + enhanced_data = _enhance_item_with_access_control(data, current_user) + await self.sql_store.upsert( + table=table, + data=enhanced_data, + conflict_columns=conflict_columns, + update_columns=update_columns, + ) + async def fetch_all( self, table: str, diff --git a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 356f49ed1..cfc3131f4 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement: class SqlAlchemySqlStoreImpl(SqlStore): def __init__(self, config: SqlAlchemySqlStoreConfig): self.config = config + self._is_sqlite_backend = "sqlite" in self.config.engine_str self.async_session = async_sessionmaker(self.create_engine()) self.metadata = MetaData() def create_engine(self) -> AsyncEngine: # Configure connection args for better concurrency support connect_args = {} - if "sqlite" in self.config.engine_str: + if self._is_sqlite_backend: # SQLite-specific optimizations for concurrent access # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases connect_args["timeout"] = 5.0 @@ -91,7 +92,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): ) # Enable WAL mode for SQLite to support concurrent readers and writers - if "sqlite" in self.config.engine_str: + if self._is_sqlite_backend: @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma(dbapi_conn, connection_record): @@ -151,6 +152,29 @@ class SqlAlchemySqlStoreImpl(SqlStore): await session.execute(self.metadata.tables[table].insert(), data) await session.commit() + async def upsert( + self, + table: str, + data: Mapping[str, Any], + conflict_columns: list[str], + update_columns: list[str] | None = None, + ) -> None: + table_obj = self.metadata.tables[table] + dialect_insert = self._get_dialect_insert(table_obj) + insert_stmt = dialect_insert.values(**data) + + if update_columns is None: + update_columns = [col for col in data.keys() if col not in conflict_columns] + + update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns} + conflict_cols = [table_obj.c[col] for col in conflict_columns] + + stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping) + + async with self.async_session() as session: + await session.execute(stmt) + await session.commit() + async def fetch_all( self, table: str, @@ -333,9 +357,18 @@ class SqlAlchemySqlStoreImpl(SqlStore): add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}") await conn.execute(add_column_sql) - except Exception as e: # If any error occurs during migration, log it but don't fail # The table creation will handle adding the column logger.error(f"Error adding column {column_name} to table {table}: {e}") pass + + def _get_dialect_insert(self, table: Table): + if self._is_sqlite_backend: + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + + return sqlite_insert(table) + else: + from sqlalchemy.dialects.postgresql import insert as pg_insert + + return pg_insert(table) diff --git a/src/llama_stack/providers/utils/sqlstore/sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlstore.py index 31801c4ca..9409b7d00 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from threading import Lock from typing import Annotated, cast from pydantic import Field @@ -21,6 +22,8 @@ from .api import SqlStore sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"] _SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {} +_SQLSTORE_INSTANCES: dict[str, SqlStore] = {} +_SQLSTORE_LOCKS: dict[str, Lock] = {} SqlStoreConfig = Annotated[ @@ -52,19 +55,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore: f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}" ) - if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig): - from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl + existing = _SQLSTORE_INSTANCES.get(backend_name) + if existing: + return existing - config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy() - return SqlAlchemySqlStoreImpl(config) - else: - raise ValueError(f"Unknown sqlstore type {backend_config.type}") + lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock()) + with lock: + existing = _SQLSTORE_INSTANCES.get(backend_name) + if existing: + return existing + + if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig): + from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl + + config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy() + instance = SqlAlchemySqlStoreImpl(config) + _SQLSTORE_INSTANCES[backend_name] = instance + return instance + else: + raise ValueError(f"Unknown sqlstore type {backend_config.type}") def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None: """Register the set of available SQL store backends for reference resolution.""" global _SQLSTORE_BACKENDS + global _SQLSTORE_INSTANCES _SQLSTORE_BACKENDS.clear() + _SQLSTORE_INSTANCES.clear() + _SQLSTORE_LOCKS.clear() for name, cfg in backends.items(): _SQLSTORE_BACKENDS[name] = cfg diff --git a/tests/integration/responses/test_conversation_stress.py b/tests/integration/responses/test_conversation_stress.py new file mode 100644 index 000000000..fa2471561 --- /dev/null +++ b/tests/integration/responses/test_conversation_stress.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +""" +Opt-in concurrency stress tests for the Responses↔Conversations sync path. + +By default this module is skipped to avoid lengthening CI. Set the environment +variable LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 to run it locally. Optional +knobs: + + LLAMA_STACK_CONCURRENCY_WORKERS (default: 4) + LLAMA_STACK_CONCURRENCY_TURNS (default: 3) +""" + +from __future__ import annotations + +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + +from llama_stack.core.testing_context import reset_test_context, set_test_context + +if not os.getenv("LLAMA_STACK_ENABLE_CONCURRENCY_TESTS"): + pytest.skip("Set LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 to run stress tests", allow_module_level=True) + + +CONCURRENCY_WORKERS = int(os.getenv("LLAMA_STACK_CONCURRENCY_WORKERS", "4")) +TURNS_PER_WORKER = int(os.getenv("LLAMA_STACK_CONCURRENCY_TURNS", "1")) + +# Reuse the exact prompts from test_conversation_responses to hit existing recordings. +FIRST_TURN_PROMPT = "Say hello" +STREAMING_PROMPT = "Say goodbye" +RECORDED_TEST_NODE_ID = ( + "tests/integration/responses/test_conversation_responses.py::TestConversationResponses::" + "test_conversation_multi_turn_and_streaming[txt=openai/gpt-4o]" +) + + +@pytest.mark.integration +def test_conversation_streaming_multi_client(openai_client, text_model_id): + """Run many copies of the conversation streaming flow concurrently (matching recorded prompts).""" + + if CONCURRENCY_WORKERS < 2: + pytest.skip("Need at least 2 workers for concurrency stress") + + def stream_single_turn(conversation_id: str, worker_idx: int, turn_idx: int): + """Send the streaming turn and ensure completion.""" + response_stream = openai_client.responses.create( + model=text_model_id, + input=[{"role": "user", "content": STREAMING_PROMPT}], + conversation=conversation_id, + stream=True, + ) + + final_response = None + failed_error: str | None = None + for chunk in response_stream: + chunk_type = getattr(chunk, "type", None) + if chunk_type is None and isinstance(chunk, dict): + chunk_type = chunk.get("type") + + chunk_response = getattr(chunk, "response", None) + if chunk_response is None and isinstance(chunk, dict): + chunk_response = chunk.get("response") + + if chunk_type in {"response.completed", "response.incomplete"}: + final_response = chunk_response + break + if chunk_type == "response.failed": + error_obj = None + if chunk_response is not None: + error_obj = getattr(chunk_response, "error", None) + if error_obj is None and isinstance(chunk_response, dict): + error_obj = chunk_response.get("error") + if isinstance(error_obj, dict): + failed_error = error_obj.get("message", "response.failed") + elif hasattr(error_obj, "message"): + failed_error = error_obj.message # type: ignore[assignment] + else: + failed_error = "response.failed event without error" + break + + if failed_error: + raise RuntimeError(f"Worker {worker_idx} turn {turn_idx} failed: {failed_error}") + if final_response is None: + raise RuntimeError(f"Worker {worker_idx} turn {turn_idx} did not complete") + + def worker(worker_idx: int) -> list[str]: + token = set_test_context(RECORDED_TEST_NODE_ID) + conversation_ids: list[str] = [] + try: + for turn_idx in range(TURNS_PER_WORKER): + conversation = openai_client.conversations.create() + assert conversation.id.startswith("conv_") + conversation_ids.append(conversation.id) + + openai_client.responses.create( + model=text_model_id, + input=[{"role": "user", "content": FIRST_TURN_PROMPT}], + conversation=conversation.id, + ) + + stream_single_turn(conversation.id, worker_idx, turn_idx) + + return conversation_ids + finally: + reset_test_context(token) + + all_conversation_ids: list[str] = [] + futures = [] + with ThreadPoolExecutor(max_workers=CONCURRENCY_WORKERS) as executor: + for worker_idx in range(CONCURRENCY_WORKERS): + futures.append(executor.submit(worker, worker_idx)) + + for future in as_completed(futures): + all_conversation_ids.extend(future.result()) + + expected_conversations = CONCURRENCY_WORKERS * TURNS_PER_WORKER + assert len(all_conversation_ids) == expected_conversations + + for conversation_id in all_conversation_ids: + conversation_items = openai_client.conversations.items.list(conversation_id) + assert len(conversation_items.data) >= 4, ( + f"Conversation {conversation_id} missing expected items: {len(conversation_items.data)}" + ) diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index 00669b698..d7ba0dc89 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory import pytest -from llama_stack.providers.utils.sqlstore.api import ColumnType +from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig @@ -65,6 +65,38 @@ async def test_sqlite_sqlstore(): assert result.has_more is False +async def test_sqlstore_upsert_support(): + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/upsert.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + await store.create_table( + "items", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "value": ColumnType.STRING, + "updated_at": ColumnType.INTEGER, + }, + ) + + await store.upsert( + table="items", + data={"id": "item_1", "value": "first", "updated_at": 1}, + conflict_columns=["id"], + ) + row = await store.fetch_one("items", {"id": "item_1"}) + assert row == {"id": "item_1", "value": "first", "updated_at": 1} + + await store.upsert( + table="items", + data={"id": "item_1", "value": "second", "updated_at": 2}, + conflict_columns=["id"], + update_columns=["value", "updated_at"], + ) + row = await store.fetch_one("items", {"id": "item_1"}) + assert row == {"id": "item_1", "value": "second", "updated_at": 2} + + async def test_sqlstore_pagination_basic(): """Test basic pagination functionality at the SQL store level.""" with TemporaryDirectory() as tmp_dir: