diff --git a/src/llama_stack/core/conversations/conversations.py b/src/llama_stack/core/conversations/conversations.py
index 951de5e9d..f83834522 100644
--- a/src/llama_stack/core/conversations/conversations.py
+++ b/src/llama_stack/core/conversations/conversations.py
@@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations):
                 "item_data": item_dict,
             }
 
-            # TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
-            try:
-                await self.sql_store.insert(table="conversation_items", data=item_record)
-            except Exception:
-                # If insert fails due to ID conflict, update existing record
-                await self.sql_store.update(
-                    table="conversation_items",
-                    data={"created_at": created_at, "item_data": item_dict},
-                    where={"id": item_id},
-                )
+            await self.sql_store.upsert(
+                table="conversation_items",
+                data=item_record,
+                conflict_columns=["id"],
+            )
 
             created_items.append(item_dict)
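
Note on the change above: the removed fallback caught every exception, not only duplicate-key conflicts, and the insert-then-update pair was not atomic. Reproduced here with the failure modes annotated (illustrative; the variable names come from the surrounding method):

    # The pattern removed above. The bare `except Exception` means a closed
    # connection or schema error also falls through to update(), silently
    # downgrading a real failure into a best-effort write; and the update is
    # not atomic with the failed insert, leaving a race window between them.
    async def _store_item_old(sql_store, item_id, created_at, item_record, item_dict):
        try:
            await sql_store.insert(table="conversation_items", data=item_record)
        except Exception:
            await sql_store.update(
                table="conversation_items",
                data={"created_at": created_at, "item_data": item_dict},
                where={"id": item_id},
            )

The new `upsert` pushes conflict resolution into the database, so both problems disappear at once.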
diff --git a/src/llama_stack/providers/utils/kvstore/kvstore.py b/src/llama_stack/providers/utils/kvstore/kvstore.py
index eee51e5d9..5b8d77102 100644
--- a/src/llama_stack/providers/utils/kvstore/kvstore.py
+++ b/src/llama_stack/providers/utils/kvstore/kvstore.py
@@ -11,6 +11,9 @@
 
 from __future__ import annotations
 
+import asyncio
+from collections import defaultdict
+
 from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
 
 from .api import KVStore
@@ -53,45 +56,63 @@ class InmemoryKVStoreImpl(KVStore):
 
 
 _KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
+_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
+_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
 
 
 def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
     """Register the set of available KV store backends for reference resolution."""
     global _KVSTORE_BACKENDS
+    global _KVSTORE_INSTANCES
+    global _KVSTORE_LOCKS
 
     _KVSTORE_BACKENDS.clear()
+    _KVSTORE_INSTANCES.clear()
+    _KVSTORE_LOCKS.clear()
     for name, cfg in backends.items():
         _KVSTORE_BACKENDS[name] = cfg
 
 
 async def kvstore_impl(reference: KVStoreReference) -> KVStore:
     backend_name = reference.backend
+    cache_key = (backend_name, reference.namespace)
+
+    existing = _KVSTORE_INSTANCES.get(cache_key)
+    if existing:
+        return existing
 
     backend_config = _KVSTORE_BACKENDS.get(backend_name)
     if backend_config is None:
         raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
 
-    config = backend_config.model_copy()
-    config.namespace = reference.namespace
+    lock = _KVSTORE_LOCKS[cache_key]
+    async with lock:
+        existing = _KVSTORE_INSTANCES.get(cache_key)
+        if existing:
+            return existing
 
-    if config.type == StorageBackendType.KV_REDIS.value:
-        from .redis import RedisKVStoreImpl
+        config = backend_config.model_copy()
+        config.namespace = reference.namespace
 
-        impl = RedisKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_SQLITE.value:
-        from .sqlite import SqliteKVStoreImpl
+        if config.type == StorageBackendType.KV_REDIS.value:
+            from .redis import RedisKVStoreImpl
 
-        impl = SqliteKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_POSTGRES.value:
-        from .postgres import PostgresKVStoreImpl
+            impl = RedisKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_SQLITE.value:
+            from .sqlite import SqliteKVStoreImpl
 
-        impl = PostgresKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_MONGODB.value:
-        from .mongodb import MongoDBKVStoreImpl
+            impl = SqliteKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_POSTGRES.value:
+            from .postgres import PostgresKVStoreImpl
 
-        impl = MongoDBKVStoreImpl(config)
-    else:
-        raise ValueError(f"Unknown kvstore type {config.type}")
+            impl = PostgresKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_MONGODB.value:
+            from .mongodb import MongoDBKVStoreImpl
 
-    await impl.initialize()
-    return impl
+            impl = MongoDBKVStoreImpl(config)
+        else:
+            raise ValueError(f"Unknown kvstore type {config.type}")
+
+        await impl.initialize()
+        _KVSTORE_INSTANCES[cache_key] = impl
+        return impl
diff --git a/src/llama_stack/providers/utils/responses/responses_store.py b/src/llama_stack/providers/utils/responses/responses_store.py
index f5024a9ed..fdca8ddee 100644
--- a/src/llama_stack/providers/utils/responses/responses_store.py
+++ b/src/llama_stack/providers/utils/responses/responses_store.py
@@ -252,19 +252,12 @@ class ResponsesStore:
         # Serialize messages to dict format for JSON storage
         messages_data = [msg.model_dump() for msg in messages]
 
-        # Upsert: try insert first, update if exists
-        try:
-            await self.sql_store.insert(
-                table="conversation_messages",
-                data={"conversation_id": conversation_id, "messages": messages_data},
-            )
-        except Exception:
-            # If insert fails due to ID conflict, update existing record
-            await self.sql_store.update(
-                table="conversation_messages",
-                data={"messages": messages_data},
-                where={"conversation_id": conversation_id},
-            )
+        await self.sql_store.upsert(
+            table="conversation_messages",
+            data={"conversation_id": conversation_id, "messages": messages_data},
+            conflict_columns=["conversation_id"],
+            update_columns=["messages"],
+        )
 
         logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")
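
The caching added to kvstore.py is classic double-checked locking. A minimal standalone sketch of the pattern (generic names, not from this patch) shows why the second lookup inside the lock matters: without it, two tasks that both miss the fast path would each construct and initialize a backend:

    import asyncio
    from collections import defaultdict

    _instances: dict[str, object] = {}
    _locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

    async def get_or_create(key: str, factory) -> object:
        if (existing := _instances.get(key)) is not None:
            return existing  # fast path: no lock once the instance exists
        async with _locks[key]:
            if (existing := _instances.get(key)) is not None:
                return existing  # another task initialized it while we waited
            instance = await factory()  # slow path runs at most once per key
            _instances[key] = instance
            return instance

Note that `kvstore_impl` keys the cache by (backend, namespace), so distinct namespaces still get their own stores, whereas `sqlstore_impl` further down keys by backend name alone.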
diff --git a/src/llama_stack/providers/utils/sqlstore/api.py b/src/llama_stack/providers/utils/sqlstore/api.py
index a61fd1090..bcd224234 100644
--- a/src/llama_stack/providers/utils/sqlstore/api.py
+++ b/src/llama_stack/providers/utils/sqlstore/api.py
@@ -47,6 +47,18 @@ class SqlStore(Protocol):
         """
         pass
 
+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        """
+        Insert a row, updating update_columns (all non-conflict columns by default) if the conflict_columns values already exist.
+        """
+        pass
+
     async def fetch_all(
         self,
         table: str,
diff --git a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py
index eb2d9a491..ba95dd120 100644
--- a/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py
+++ b/src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py
@@ -129,6 +129,23 @@ class AuthorizedSqlStore:
         enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
         await self.sql_store.insert(table, enhanced_data)
 
+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        """Upsert a row with automatic access control attribute capture."""
+        current_user = get_authenticated_user()
+        enhanced_data = _enhance_item_with_access_control(data, current_user)
+        await self.sql_store.upsert(
+            table=table,
+            data=enhanced_data,
+            conflict_columns=conflict_columns,
+            update_columns=update_columns,
+        )
+
     async def fetch_all(
         self,
         table: str,
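
For reference, the new protocol method behaves roughly like the following non-atomic sketch, except that the real implementation issues a single INSERT ... ON CONFLICT statement and therefore has no read-then-write window (the helper is illustrative, not part of the API):

    # Non-atomic reference semantics for SqlStore.upsert (illustration only).
    async def upsert_reference(store, table, data, conflict_columns, update_columns=None):
        key = {col: data[col] for col in conflict_columns}
        if update_columns is None:
            # Default mirrors the SqlAlchemy implementation: update every non-conflict column.
            update_columns = [col for col in data if col not in conflict_columns]
        existing = await store.fetch_one(table, key)
        if existing is None:
            await store.insert(table=table, data=data)
        else:
            await store.update(table=table, data={col: data[col] for col in update_columns}, where=key)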
diff --git a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py
index 356f49ed1..cfc3131f4 100644
--- a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py
+++ b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py
@@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
 class SqlAlchemySqlStoreImpl(SqlStore):
     def __init__(self, config: SqlAlchemySqlStoreConfig):
         self.config = config
+        self._is_sqlite_backend = "sqlite" in self.config.engine_str
         self.async_session = async_sessionmaker(self.create_engine())
         self.metadata = MetaData()
 
     def create_engine(self) -> AsyncEngine:
         # Configure connection args for better concurrency support
         connect_args = {}
-        if "sqlite" in self.config.engine_str:
+        if self._is_sqlite_backend:
             # SQLite-specific optimizations for concurrent access
             # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
             connect_args["timeout"] = 5.0
@@ -91,7 +92,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         )
 
         # Enable WAL mode for SQLite to support concurrent readers and writers
-        if "sqlite" in self.config.engine_str:
+        if self._is_sqlite_backend:
 
             @event.listens_for(engine.sync_engine, "connect")
             def set_sqlite_pragma(dbapi_conn, connection_record):
@@ -151,6 +152,29 @@ class SqlAlchemySqlStoreImpl(SqlStore):
             await session.execute(self.metadata.tables[table].insert(), data)
             await session.commit()
 
+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        table_obj = self.metadata.tables[table]
+        dialect_insert = self._get_dialect_insert(table_obj)
+        insert_stmt = dialect_insert.values(**data)
+
+        if update_columns is None:
+            update_columns = [col for col in data.keys() if col not in conflict_columns]
+
+        update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
+        conflict_cols = [table_obj.c[col] for col in conflict_columns]
+
+        stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
+
+        async with self.async_session() as session:
+            await session.execute(stmt)
+            await session.commit()
+
     async def fetch_all(
         self,
         table: str,
@@ -333,9 +357,18 @@ class SqlAlchemySqlStoreImpl(SqlStore):
                     add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
                     await conn.execute(add_column_sql)
-
                 except Exception as e:
                     # If any error occurs during migration, log it but don't fail
                     # The table creation will handle adding the column
                     logger.error(f"Error adding column {column_name} to table {table}: {e}")
                     pass
+
+    def _get_dialect_insert(self, table: Table):
+        if self._is_sqlite_backend:
+            from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+            return sqlite_insert(table)
+        else:
+            from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+            return pg_insert(table)
diff --git a/src/llama_stack/providers/utils/sqlstore/sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlstore.py
index 31801c4ca..9409b7d00 100644
--- a/src/llama_stack/providers/utils/sqlstore/sqlstore.py
+++ b/src/llama_stack/providers/utils/sqlstore/sqlstore.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from threading import Lock
 from typing import Annotated, cast
 
 from pydantic import Field
@@ -21,6 +22,8 @@ from .api import SqlStore
 sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
 
 _SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
+_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
+_SQLSTORE_LOCKS: dict[str, Lock] = {}
 
 
 SqlStoreConfig = Annotated[
@@ -52,19 +55,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
             f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
         )
 
-    if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
-        from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
+    existing = _SQLSTORE_INSTANCES.get(backend_name)
+    if existing:
+        return existing
 
-        config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
-        return SqlAlchemySqlStoreImpl(config)
-    else:
-        raise ValueError(f"Unknown sqlstore type {backend_config.type}")
+    lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
+    with lock:
+        existing = _SQLSTORE_INSTANCES.get(backend_name)
+        if existing:
+            return existing
+
+        if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
+            from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
+
+            config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
+            instance = SqlAlchemySqlStoreImpl(config)
+            _SQLSTORE_INSTANCES[backend_name] = instance
+            return instance
+        else:
+            raise ValueError(f"Unknown sqlstore type {backend_config.type}")
 
 
 def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
     """Register the set of available SQL store backends for reference resolution."""
     global _SQLSTORE_BACKENDS
+    global _SQLSTORE_INSTANCES
 
     _SQLSTORE_BACKENDS.clear()
+    _SQLSTORE_INSTANCES.clear()
+    _SQLSTORE_LOCKS.clear()
     for name, cfg in backends.items():
         _SQLSTORE_BACKENDS[name] = cfg
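
The implementation leans on SQLAlchemy's dialect-specific insert constructs; both the SQLite and Postgres variants expose on_conflict_do_update() and an `excluded` namespace referring to the values the INSERT attempted. A standalone sketch of the SQL this produces (SQLite dialect shown; table and values are illustrative):

    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.dialects import sqlite
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    metadata = MetaData()
    items = Table("items", metadata, Column("id", String, primary_key=True), Column("value", String))

    ins = sqlite_insert(items).values(id="item_1", value="second")
    stmt = ins.on_conflict_do_update(index_elements=[items.c.id], set_={"value": ins.excluded.value})
    print(stmt.compile(dialect=sqlite.dialect(), compile_kwargs={"literal_binds": True}))
    # INSERT INTO items (id, value) VALUES ('item_1', 'second')
    # ON CONFLICT (id) DO UPDATE SET value = excluded.value

Two caveats follow from the code above: _get_dialect_insert assumes the backend is either SQLite or Postgres, which matches the two configs sqlstore.py accepts; and since DO UPDATE SET requires at least one assignment (and Postgres additionally requires a unique constraint or index on the conflict columns), callers omitting update_columns should pass at least one non-conflict column in data.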
diff --git a/tests/integration/responses/test_conversation_stress.py b/tests/integration/responses/test_conversation_stress.py
new file mode 100644
index 000000000..fa2471561
--- /dev/null
+++ b/tests/integration/responses/test_conversation_stress.py
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+"""
+Opt-in concurrency stress tests for the Responses↔Conversations sync path.
+
+By default this module is skipped to avoid lengthening CI. Set the environment
+variable LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 to run it locally. Optional
+knobs:
+
+    LLAMA_STACK_CONCURRENCY_WORKERS (default: 4)
+    LLAMA_STACK_CONCURRENCY_TURNS (default: 1)
+"""
+
+from __future__ import annotations
+
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import pytest
+
+from llama_stack.core.testing_context import reset_test_context, set_test_context
+
+if not os.getenv("LLAMA_STACK_ENABLE_CONCURRENCY_TESTS"):
+    pytest.skip("Set LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 to run stress tests", allow_module_level=True)
+
+
+CONCURRENCY_WORKERS = int(os.getenv("LLAMA_STACK_CONCURRENCY_WORKERS", "4"))
+TURNS_PER_WORKER = int(os.getenv("LLAMA_STACK_CONCURRENCY_TURNS", "1"))
+
+# Reuse the exact prompts from test_conversation_responses to hit existing recordings.
+FIRST_TURN_PROMPT = "Say hello"
+STREAMING_PROMPT = "Say goodbye"
+RECORDED_TEST_NODE_ID = (
+    "tests/integration/responses/test_conversation_responses.py::TestConversationResponses::"
+    "test_conversation_multi_turn_and_streaming[txt=openai/gpt-4o]"
+)
+
+
+@pytest.mark.integration
+def test_conversation_streaming_multi_client(openai_client, text_model_id):
+    """Run many copies of the conversation streaming flow concurrently (matching recorded prompts)."""
+
+    if CONCURRENCY_WORKERS < 2:
+        pytest.skip("Need at least 2 workers for concurrency stress")
+
+    def stream_single_turn(conversation_id: str, worker_idx: int, turn_idx: int):
+        """Send the streaming turn and ensure completion."""
+        response_stream = openai_client.responses.create(
+            model=text_model_id,
+            input=[{"role": "user", "content": STREAMING_PROMPT}],
+            conversation=conversation_id,
+            stream=True,
+        )
+
+        final_response = None
+        failed_error: str | None = None
+        for chunk in response_stream:
+            chunk_type = getattr(chunk, "type", None)
+            if chunk_type is None and isinstance(chunk, dict):
+                chunk_type = chunk.get("type")
+
+            chunk_response = getattr(chunk, "response", None)
+            if chunk_response is None and isinstance(chunk, dict):
+                chunk_response = chunk.get("response")
+
+            if chunk_type in {"response.completed", "response.incomplete"}:
+                final_response = chunk_response
+                break
+            if chunk_type == "response.failed":
+                error_obj = None
+                if chunk_response is not None:
+                    error_obj = getattr(chunk_response, "error", None)
+                if error_obj is None and isinstance(chunk_response, dict):
+                    error_obj = chunk_response.get("error")
+                if isinstance(error_obj, dict):
+                    failed_error = error_obj.get("message", "response.failed")
+                elif hasattr(error_obj, "message"):
+                    failed_error = error_obj.message  # type: ignore[assignment]
+                else:
+                    failed_error = "response.failed event without error"
+                break
+
+        if failed_error:
+            raise RuntimeError(f"Worker {worker_idx} turn {turn_idx} failed: {failed_error}")
+        if final_response is None:
+            raise RuntimeError(f"Worker {worker_idx} turn {turn_idx} did not complete")
+
+    def worker(worker_idx: int) -> list[str]:
+        token = set_test_context(RECORDED_TEST_NODE_ID)
+        conversation_ids: list[str] = []
+        try:
+            for turn_idx in range(TURNS_PER_WORKER):
+                conversation = openai_client.conversations.create()
+                assert conversation.id.startswith("conv_")
+                conversation_ids.append(conversation.id)
+
+                openai_client.responses.create(
+                    model=text_model_id,
+                    input=[{"role": "user", "content": FIRST_TURN_PROMPT}],
+                    conversation=conversation.id,
+                )
+
+                stream_single_turn(conversation.id, worker_idx, turn_idx)
+
+            return conversation_ids
+        finally:
+            reset_test_context(token)
+
+    all_conversation_ids: list[str] = []
+    futures = []
+    with ThreadPoolExecutor(max_workers=CONCURRENCY_WORKERS) as executor:
+        for worker_idx in range(CONCURRENCY_WORKERS):
+            futures.append(executor.submit(worker, worker_idx))
+
+        for future in as_completed(futures):
+            all_conversation_ids.extend(future.result())
+
+    expected_conversations = CONCURRENCY_WORKERS * TURNS_PER_WORKER
+    assert len(all_conversation_ids) == expected_conversations
+
+    for conversation_id in all_conversation_ids:
+        conversation_items = openai_client.conversations.items.list(conversation_id)
+        assert len(conversation_items.data) >= 4, (
+            f"Conversation {conversation_id} missing expected items: {len(conversation_items.data)}"
+        )
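
On the closing assertion: each conversation runs two turns (one non-streaming, one streaming), and each turn is expected to persist at least a user item and an assistant item, hence the >= 4 floor. A sketch of the arithmetic (assumes two stored items per turn; responses may store more, which is why the test asserts a floor rather than equality):

    ITEMS_PER_TURN = 2           # user input item + assistant output item
    TURNS_PER_CONVERSATION = 2   # one non-streaming turn, then one streaming turn
    assert ITEMS_PER_TURN * TURNS_PER_CONVERSATION == 4

To run locally: LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 pytest tests/integration/responses/test_conversation_stress.py, optionally with the worker/turn knobs from the module docstring.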
diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py
index 00669b698..d7ba0dc89 100644
--- a/tests/unit/utils/sqlstore/test_sqlstore.py
+++ b/tests/unit/utils/sqlstore/test_sqlstore.py
@@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory
 
 import pytest
 
-from llama_stack.providers.utils.sqlstore.api import ColumnType
+from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 
@@ -65,6 +65,38 @@ async def test_sqlite_sqlstore():
     assert result.has_more is False
 
 
+async def test_sqlstore_upsert_support():
+    with TemporaryDirectory() as tmp_dir:
+        db_path = tmp_dir + "/upsert.db"
+        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
+
+        await store.create_table(
+            "items",
+            {
+                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
+                "value": ColumnType.STRING,
+                "updated_at": ColumnType.INTEGER,
+            },
+        )
+
+        await store.upsert(
+            table="items",
+            data={"id": "item_1", "value": "first", "updated_at": 1},
+            conflict_columns=["id"],
+        )
+        row = await store.fetch_one("items", {"id": "item_1"})
+        assert row == {"id": "item_1", "value": "first", "updated_at": 1}
+
+        await store.upsert(
+            table="items",
+            data={"id": "item_1", "value": "second", "updated_at": 2},
+            conflict_columns=["id"],
+            update_columns=["value", "updated_at"],
+        )
+        row = await store.fetch_one("items", {"id": "item_1"})
+        assert row == {"id": "item_1", "value": "second", "updated_at": 2}
+
+
 async def test_sqlstore_pagination_basic():
     """Test basic pagination functionality at the SQL store level."""
     with TemporaryDirectory() as tmp_dir:
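
A natural follow-on to the unit test above (a sketch, not part of this PR) would hammer one key with concurrent upserts and assert that exactly one row survives; with the WAL settings from sqlalchemy_sqlstore.py these writes should resolve without "database is locked" errors:

    import asyncio

    async def _write(store, n: int) -> None:
        await store.upsert(
            table="items",
            data={"id": "item_1", "value": f"v{n}", "updated_at": n},
            conflict_columns=["id"],
        )

    async def test_sqlstore_upsert_concurrent(store) -> None:
        # Hypothetical: `store` is a SqlAlchemySqlStoreImpl with the `items` table created.
        await asyncio.gather(*(_write(store, n) for n in range(16)))
        result = await store.fetch_all("items", where={"id": "item_1"})
        assert len(result.data) == 1  # one row survives; its value is whichever writer landed last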