feat(storage): share sql/kv instances and add upsert support

SQL stores now share a single SqlAlchemySqlStoreImpl per backend, and kvstore_impl caches instances per (backend, namespace) pair. This avoids spawning multiple SQLite connections against the same file, reducing lock contention and aligning caching behavior across all backends.

Added an async upsert API (backed by the SQLite/Postgres dialect-specific inserts), routed it through AuthorizedSqlStore, and switched the conversations and responses stores to call it. A native ON CONFLICT DO UPDATE eliminates the insert-then-update retry window that previously caused long WAL lock retries.
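For callers, this collapses the old insert-then-update dance into a single call. A minimal sketch of the new shape (the column choices here are illustrative):

    await sql_store.upsert(
        table="conversation_items",
        data=item_record,
        conflict_columns=["id"],       # match on the primary key
        update_columns=["item_data"],  # optional; defaults to every non-conflict column in data
    )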

Introduced an opt-in conversation stress test that mirrors the recorded prompts from test_conversation_multi_turn_and_streaming while fanning them out across many threads. This gives us a fast local way to hammer the conversations/responses sync path when investigating lockups.
Ashwin Bharambe 2025-11-12 11:27:35 -08:00
parent 531e00366f
commit c2c5ab622c
9 changed files with 300 additions and 51 deletions


@@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations):
                 "item_data": item_dict,
             }
-            # TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
-            try:
-                await self.sql_store.insert(table="conversation_items", data=item_record)
-            except Exception:
-                # If insert fails due to ID conflict, update existing record
-                await self.sql_store.update(
-                    table="conversation_items",
-                    data={"created_at": created_at, "item_data": item_dict},
-                    where={"id": item_id},
-                )
+            await self.sql_store.upsert(
+                table="conversation_items",
+                data=item_record,
+                conflict_columns=["id"],
+            )
             created_items.append(item_dict)


@@ -11,6 +11,9 @@
 from __future__ import annotations

+import asyncio
+from collections import defaultdict
+
 from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType

 from .api import KVStore
@@ -53,45 +56,63 @@ class InmemoryKVStoreImpl(KVStore):
 _KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
+_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
+_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)


 def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
     """Register the set of available KV store backends for reference resolution."""
     global _KVSTORE_BACKENDS
+    global _KVSTORE_INSTANCES
+    global _KVSTORE_LOCKS

     _KVSTORE_BACKENDS.clear()
+    _KVSTORE_INSTANCES.clear()
+    _KVSTORE_LOCKS.clear()
     for name, cfg in backends.items():
         _KVSTORE_BACKENDS[name] = cfg


 async def kvstore_impl(reference: KVStoreReference) -> KVStore:
+    backend_name = reference.backend
+    cache_key = (backend_name, reference.namespace)
+
+    existing = _KVSTORE_INSTANCES.get(cache_key)
+    if existing:
+        return existing
+
     backend_config = _KVSTORE_BACKENDS.get(backend_name)
     if backend_config is None:
         raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")

-    config = backend_config.model_copy()
-    config.namespace = reference.namespace
-
-    if config.type == StorageBackendType.KV_REDIS.value:
-        from .redis import RedisKVStoreImpl
-
-        impl = RedisKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_SQLITE.value:
-        from .sqlite import SqliteKVStoreImpl
-
-        impl = SqliteKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_POSTGRES.value:
-        from .postgres import PostgresKVStoreImpl
-
-        impl = PostgresKVStoreImpl(config)
-    elif config.type == StorageBackendType.KV_MONGODB.value:
-        from .mongodb import MongoDBKVStoreImpl
-
-        impl = MongoDBKVStoreImpl(config)
-    else:
-        raise ValueError(f"Unknown kvstore type {config.type}")
-
-    await impl.initialize()
-    return impl
+    lock = _KVSTORE_LOCKS[cache_key]
+    async with lock:
+        existing = _KVSTORE_INSTANCES.get(cache_key)
+        if existing:
+            return existing
+
+        config = backend_config.model_copy()
+        config.namespace = reference.namespace
+
+        if config.type == StorageBackendType.KV_REDIS.value:
+            from .redis import RedisKVStoreImpl
+
+            impl = RedisKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_SQLITE.value:
+            from .sqlite import SqliteKVStoreImpl
+
+            impl = SqliteKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_POSTGRES.value:
+            from .postgres import PostgresKVStoreImpl
+
+            impl = PostgresKVStoreImpl(config)
+        elif config.type == StorageBackendType.KV_MONGODB.value:
+            from .mongodb import MongoDBKVStoreImpl
+
+            impl = MongoDBKVStoreImpl(config)
+        else:
+            raise ValueError(f"Unknown kvstore type {config.type}")
+
+        await impl.initialize()
+        _KVSTORE_INSTANCES[cache_key] = impl
+        return impl
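The caching here is the classic async double-checked pattern: a lock-free fast path, then a per-key asyncio.Lock with a second lookup so concurrent callers never construct the same backend twice. A standalone sketch of the same idea (names are illustrative, not the module's API):

    import asyncio
    from collections import defaultdict

    _instances: dict[str, object] = {}
    _locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

    async def get_or_create(key: str, factory):
        existing = _instances.get(key)      # fast path, no lock
        if existing:
            return existing
        async with _locks[key]:
            existing = _instances.get(key)  # re-check: another task may have won the race
            if existing:
                return existing
            instance = await factory()      # slow path runs at most once per key
            _instances[key] = instance
            return instance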


@@ -252,19 +252,12 @@ class ResponsesStore:
         # Serialize messages to dict format for JSON storage
         messages_data = [msg.model_dump() for msg in messages]

-        # Upsert: try insert first, update if exists
-        try:
-            await self.sql_store.insert(
-                table="conversation_messages",
-                data={"conversation_id": conversation_id, "messages": messages_data},
-            )
-        except Exception:
-            # If insert fails due to ID conflict, update existing record
-            await self.sql_store.update(
-                table="conversation_messages",
-                data={"messages": messages_data},
-                where={"conversation_id": conversation_id},
-            )
+        await self.sql_store.upsert(
+            table="conversation_messages",
+            data={"conversation_id": conversation_id, "messages": messages_data},
+            conflict_columns=["conversation_id"],
+            update_columns=["messages"],
+        )

         logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")


@@ -47,6 +47,18 @@ class SqlStore(Protocol):
         """
         pass

+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        """
+        Insert a row, updating the given columns if a conflict occurs.
+
+        When update_columns is None, all non-conflict columns in data are updated.
+        """
+        pass
+
     async def fetch_all(
         self,
         table: str,
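One subtlety of the protocol: when update_columns is omitted, the implementation updates every non-conflict column present in data, so partial updates require passing the list explicitly. A hedged illustration (table and values invented for the example):

    # Overwrites `value` but leaves `created_at` untouched on conflict:
    await store.upsert(
        table="items",
        data={"id": "i1", "value": "v2", "created_at": 123},
        conflict_columns=["id"],
        update_columns=["value"],
    )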


@@ -129,6 +129,23 @@ class AuthorizedSqlStore:
         enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
         await self.sql_store.insert(table, enhanced_data)

+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        """Upsert a row with automatic access control attribute capture."""
+        current_user = get_authenticated_user()
+        enhanced_data = _enhance_item_with_access_control(data, current_user)
+        await self.sql_store.upsert(
+            table=table,
+            data=enhanced_data,
+            conflict_columns=conflict_columns,
+            update_columns=update_columns,
+        )
+
     async def fetch_all(
         self,
         table: str,


@@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
 class SqlAlchemySqlStoreImpl(SqlStore):
     def __init__(self, config: SqlAlchemySqlStoreConfig):
         self.config = config
+        self._is_sqlite_backend = "sqlite" in self.config.engine_str
         self.async_session = async_sessionmaker(self.create_engine())
         self.metadata = MetaData()

     def create_engine(self) -> AsyncEngine:
         # Configure connection args for better concurrency support
         connect_args = {}
-        if "sqlite" in self.config.engine_str:
+        if self._is_sqlite_backend:
             # SQLite-specific optimizations for concurrent access
             # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
             connect_args["timeout"] = 5.0
@@ -91,7 +92,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         )

         # Enable WAL mode for SQLite to support concurrent readers and writers
-        if "sqlite" in self.config.engine_str:
+        if self._is_sqlite_backend:

             @event.listens_for(engine.sync_engine, "connect")
             def set_sqlite_pragma(dbapi_conn, connection_record):
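The hook body falls outside this hunk; a typical WAL listener looks roughly like the sketch below. The exact pragmas are an assumption (not copied from the file), chosen to match the WAL mode and 5s timeout described above:

    @event.listens_for(engine.sync_engine, "connect")
    def set_sqlite_pragma(dbapi_conn, connection_record):
        cursor = dbapi_conn.cursor()
        cursor.execute("PRAGMA journal_mode=WAL")   # assumed: concurrent readers plus one writer
        cursor.execute("PRAGMA busy_timeout=5000")  # assumed: mirrors the 5s connect timeout
        cursor.close()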
@@ -151,6 +152,29 @@ class SqlAlchemySqlStoreImpl(SqlStore):
             await session.execute(self.metadata.tables[table].insert(), data)
             await session.commit()

+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        table_obj = self.metadata.tables[table]
+        dialect_insert = self._get_dialect_insert(table_obj)
+        insert_stmt = dialect_insert.values(**data)
+
+        if update_columns is None:
+            # Default: update every non-conflict column supplied in `data`.
+            update_columns = [col for col in data.keys() if col not in conflict_columns]
+
+        update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
+        conflict_cols = [table_obj.c[col] for col in conflict_columns]
+        stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
+
+        async with self.async_session() as session:
+            await session.execute(stmt)
+            await session.commit()
+
     async def fetch_all(
         self,
         table: str,
@@ -333,9 +357,18 @@ class SqlAlchemySqlStoreImpl(SqlStore):
                     add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
                     await conn.execute(add_column_sql)
                 except Exception as e:
                     # If any error occurs during migration, log it but don't fail
                     # The table creation will handle adding the column
                     logger.error(f"Error adding column {column_name} to table {table}: {e}")
-                    pass

+    def _get_dialect_insert(self, table: Table):
+        if self._is_sqlite_backend:
+            from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+            return sqlite_insert(table)
+        else:
+            from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+            return pg_insert(table)
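Both dialect inserts expose the same on_conflict_do_update API, which is what keeps upsert backend-agnostic. A self-contained sketch of the SQLite path (the items table is invented for illustration):

    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    metadata = MetaData()
    items = Table("items", metadata, Column("id", String, primary_key=True), Column("value", String))

    stmt = sqlite_insert(items).values(id="item_1", value="second")
    stmt = stmt.on_conflict_do_update(
        index_elements=[items.c.id],
        set_={"value": stmt.excluded.value},
    )
    # Compiles to roughly:
    #   INSERT INTO items (id, value) VALUES (?, ?)
    #   ON CONFLICT (id) DO UPDATE SET value = excluded.value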


@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from threading import Lock
 from typing import Annotated, cast

 from pydantic import Field
@@ -21,6 +22,8 @@ from .api import SqlStore
 sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]

 _SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
+_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
+_SQLSTORE_LOCKS: dict[str, Lock] = {}

 SqlStoreConfig = Annotated[
@@ -52,19 +55,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
             f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
         )

-    if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
-        from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
-
-        config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
-        return SqlAlchemySqlStoreImpl(config)
-    else:
-        raise ValueError(f"Unknown sqlstore type {backend_config.type}")
+    existing = _SQLSTORE_INSTANCES.get(backend_name)
+    if existing:
+        return existing
+
+    lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
+    with lock:
+        existing = _SQLSTORE_INSTANCES.get(backend_name)
+        if existing:
+            return existing
+
+        if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
+            from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
+
+            config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
+            instance = SqlAlchemySqlStoreImpl(config)
+            _SQLSTORE_INSTANCES[backend_name] = instance
+            return instance
+        else:
+            raise ValueError(f"Unknown sqlstore type {backend_config.type}")


 def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
     """Register the set of available SQL store backends for reference resolution."""
     global _SQLSTORE_BACKENDS
+    global _SQLSTORE_INSTANCES

     _SQLSTORE_BACKENDS.clear()
+    _SQLSTORE_INSTANCES.clear()
+    _SQLSTORE_LOCKS.clear()
     for name, cfg in backends.items():
         _SQLSTORE_BACKENDS[name] = cfg
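The practical effect: every reference that targets the same backend now resolves to one engine and one connection pool. A hedged sketch of the observable behavior (the reference fields are illustrative, not the exact SqlStoreReference schema):

    ref_a = SqlStoreReference(backend="sql_default", table_name="conversations")  # hypothetical fields
    ref_b = SqlStoreReference(backend="sql_default", table_name="responses")
    assert sqlstore_impl(ref_a) is sqlstore_impl(ref_b)  # one shared SqlAlchemySqlStoreImpl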


@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Opt-in concurrency stress tests for the Responses/Conversations sync path.
+
+By default this module is skipped to avoid lengthening CI. Set the environment
+variable LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 to run it locally. Optional
+knobs:
+
+    LLAMA_STACK_CONCURRENCY_WORKERS (default: 4)
+    LLAMA_STACK_CONCURRENCY_TURNS (default: 1)
+"""
+
+from __future__ import annotations
+
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import pytest
+
+from llama_stack.core.testing_context import reset_test_context, set_test_context
+
+if not os.getenv("LLAMA_STACK_ENABLE_CONCURRENCY_TESTS"):
+    pytest.skip("Set LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 to run stress tests", allow_module_level=True)
+
+CONCURRENCY_WORKERS = int(os.getenv("LLAMA_STACK_CONCURRENCY_WORKERS", "4"))
+TURNS_PER_WORKER = int(os.getenv("LLAMA_STACK_CONCURRENCY_TURNS", "1"))
+
+# Reuse the exact prompts from test_conversation_responses to hit existing recordings.
+FIRST_TURN_PROMPT = "Say hello"
+STREAMING_PROMPT = "Say goodbye"
+RECORDED_TEST_NODE_ID = (
+    "tests/integration/responses/test_conversation_responses.py::TestConversationResponses::"
+    "test_conversation_multi_turn_and_streaming[txt=openai/gpt-4o]"
+)
+
+
+@pytest.mark.integration
+def test_conversation_streaming_multi_client(openai_client, text_model_id):
+    """Run many copies of the conversation streaming flow concurrently (matching recorded prompts)."""
+    if CONCURRENCY_WORKERS < 2:
+        pytest.skip("Need at least 2 workers for concurrency stress")
+
+    def stream_single_turn(conversation_id: str, worker_idx: int, turn_idx: int):
+        """Send the streaming turn and ensure completion."""
+        response_stream = openai_client.responses.create(
+            model=text_model_id,
+            input=[{"role": "user", "content": STREAMING_PROMPT}],
+            conversation=conversation_id,
+            stream=True,
+        )
+
+        final_response = None
+        failed_error: str | None = None
+        for chunk in response_stream:
+            chunk_type = getattr(chunk, "type", None)
+            if chunk_type is None and isinstance(chunk, dict):
+                chunk_type = chunk.get("type")
+            chunk_response = getattr(chunk, "response", None)
+            if chunk_response is None and isinstance(chunk, dict):
+                chunk_response = chunk.get("response")
+
+            if chunk_type in {"response.completed", "response.incomplete"}:
+                final_response = chunk_response
+                break
+            if chunk_type == "response.failed":
+                error_obj = None
+                if chunk_response is not None:
+                    error_obj = getattr(chunk_response, "error", None)
+                    if error_obj is None and isinstance(chunk_response, dict):
+                        error_obj = chunk_response.get("error")
+                if isinstance(error_obj, dict):
+                    failed_error = error_obj.get("message", "response.failed")
+                elif hasattr(error_obj, "message"):
+                    failed_error = error_obj.message  # type: ignore[assignment]
+                else:
+                    failed_error = "response.failed event without error"
+                break
+
+        if failed_error:
+            raise RuntimeError(f"Worker {worker_idx} turn {turn_idx} failed: {failed_error}")
+        if final_response is None:
+            raise RuntimeError(f"Worker {worker_idx} turn {turn_idx} did not complete")
+
+    def worker(worker_idx: int) -> list[str]:
+        token = set_test_context(RECORDED_TEST_NODE_ID)
+        conversation_ids: list[str] = []
+        try:
+            for turn_idx in range(TURNS_PER_WORKER):
+                conversation = openai_client.conversations.create()
+                assert conversation.id.startswith("conv_")
+                conversation_ids.append(conversation.id)
+
+                openai_client.responses.create(
+                    model=text_model_id,
+                    input=[{"role": "user", "content": FIRST_TURN_PROMPT}],
+                    conversation=conversation.id,
+                )
+                stream_single_turn(conversation.id, worker_idx, turn_idx)
+            return conversation_ids
+        finally:
+            reset_test_context(token)
+
+    all_conversation_ids: list[str] = []
+    futures = []
+    with ThreadPoolExecutor(max_workers=CONCURRENCY_WORKERS) as executor:
+        for worker_idx in range(CONCURRENCY_WORKERS):
+            futures.append(executor.submit(worker, worker_idx))
+        for future in as_completed(futures):
+            all_conversation_ids.extend(future.result())
+
+    expected_conversations = CONCURRENCY_WORKERS * TURNS_PER_WORKER
+    assert len(all_conversation_ids) == expected_conversations
+
+    for conversation_id in all_conversation_ids:
+        conversation_items = openai_client.conversations.items.list(conversation_id)
+        assert len(conversation_items.data) >= 4, (
+            f"Conversation {conversation_id} missing expected items: {len(conversation_items.data)}"
+        )


@@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory

 import pytest

-from llama_stack.providers.utils.sqlstore.api import ColumnType
+from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@@ -65,6 +65,38 @@ async def test_sqlite_sqlstore():
         assert result.has_more is False


+async def test_sqlstore_upsert_support():
+    with TemporaryDirectory() as tmp_dir:
+        db_path = tmp_dir + "/upsert.db"
+        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
+
+        await store.create_table(
+            "items",
+            {
+                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
+                "value": ColumnType.STRING,
+                "updated_at": ColumnType.INTEGER,
+            },
+        )
+
+        # First upsert inserts the row.
+        await store.upsert(
+            table="items",
+            data={"id": "item_1", "value": "first", "updated_at": 1},
+            conflict_columns=["id"],
+        )
+        row = await store.fetch_one("items", {"id": "item_1"})
+        assert row == {"id": "item_1", "value": "first", "updated_at": 1}
+
+        # Second upsert conflicts on id and updates the listed columns.
+        await store.upsert(
+            table="items",
+            data={"id": "item_1", "value": "second", "updated_at": 2},
+            conflict_columns=["id"],
+            update_columns=["value", "updated_at"],
+        )
+        row = await store.fetch_one("items", {"id": "item_1"})
+        assert row == {"id": "item_1", "value": "second", "updated_at": 2}
+
+
 async def test_sqlstore_pagination_basic():
     """Test basic pagination functionality at the SQL store level."""
     with TemporaryDirectory() as tmp_dir: