mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(storage): share sql/kv instances and add upsert support
SQL stores now share a single SqlAlchemySqlStoreImpl per backend, and kvstore_impl caches instances per (backend, namespace). This avoids opening multiple SQLite connections to the same file, which reduces lock contention, and gives every backend the same instance-caching behavior.

Added an async upsert API (using the SQLite/Postgres dialect-specific inserts), routed it through AuthorizedSqlStore, and switched conversations and responses to call it. Using native ON CONFLICT DO UPDATE eliminates the insert-then-update retry window that previously caused long WAL lock retries.

Introduced an opt-in conversation stress test that replays the recorded prompts from test_conversation_multi_turn_and_streaming while fanning them out across many threads. This gives us a fast local way to hammer the conversations/responses sync path when investigating lockups.
This commit is contained in:
parent
531e00366f
commit
c2c5ab622c
9 changed files with 300 additions and 51 deletions
|
|
@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations):
|
|||
"item_data": item_dict,
|
||||
}
|
||||
|
||||
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
|
||||
try:
|
||||
await self.sql_store.insert(table="conversation_items", data=item_record)
|
||||
except Exception:
|
||||
# If insert fails due to ID conflict, update existing record
|
||||
await self.sql_store.update(
|
||||
table="conversation_items",
|
||||
data={"created_at": created_at, "item_data": item_dict},
|
||||
where={"id": item_id},
|
||||
)
|
||||
await self.sql_store.upsert(
|
||||
table="conversation_items",
|
||||
data=item_record,
|
||||
conflict_columns=["id"],
|
||||
)
|
||||
|
||||
created_items.append(item_dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
|
||||
|
||||
from .api import KVStore
|
||||
|
|
@ -53,45 +56,63 @@ class InmemoryKVStoreImpl(KVStore):
|
|||
|
||||
|
||||
# Backend configurations keyed by backend name; populated by register_kvstore_backends().
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
# Store instances keyed by (backend name, namespace).
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
# One asyncio lock per (backend name, namespace) key, created on first access.
_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)


def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Replace the set of KV store backends available for reference resolution.

    All three module-level registries (backend configs, cached store instances
    and per-key locks) are emptied in place before ``backends`` is recorded, so
    nothing registered or cached earlier survives a call.

    Args:
        backends: Mapping of backend name to its configuration.
    """
    for registry in (_KVSTORE_BACKENDS, _KVSTORE_INSTANCES, _KVSTORE_LOCKS):
        registry.clear()
    _KVSTORE_BACKENDS.update(backends)
|
||||
|
||||
|
||||
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
    """Return the initialized KV store for ``reference``, creating it on first use.

    Instances are cached in ``_KVSTORE_INSTANCES`` under the key
    ``(reference.backend, reference.namespace)``, so repeated lookups for the
    same key return the same object. Creation is guarded by the per-key
    ``asyncio.Lock`` from ``_KVSTORE_LOCKS`` and the cache is re-checked once
    the lock is held, so coroutines racing on the same key end up with a single
    instance.

    Args:
        reference: Names the registered backend and the namespace to use.

    Returns:
        The cached or newly created store, after ``initialize()`` has been awaited.

    Raises:
        ValueError: If the backend name is not registered, or the backend's
            ``type`` is not one of the supported KV store types.
    """
    backend_name = reference.backend
    cache_key = (backend_name, reference.namespace)

    # Fast path without taking the lock. Compare against None explicitly so a
    # cached store that happens to be falsy is still treated as a hit.
    existing = _KVSTORE_INSTANCES.get(cache_key)
    if existing is not None:
        return existing

    backend_config = _KVSTORE_BACKENDS.get(backend_name)
    if backend_config is None:
        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")

    async with _KVSTORE_LOCKS[cache_key]:
        # Another coroutine may have created the instance while we waited.
        existing = _KVSTORE_INSTANCES.get(cache_key)
        if existing is not None:
            return existing

        # Work on a copy so the registered backend config is never mutated.
        config = backend_config.model_copy()
        config.namespace = reference.namespace

        impl: KVStore
        if config.type == StorageBackendType.KV_REDIS.value:
            from .redis import RedisKVStoreImpl

            impl = RedisKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_SQLITE.value:
            from .sqlite import SqliteKVStoreImpl

            impl = SqliteKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_POSTGRES.value:
            from .postgres import PostgresKVStoreImpl

            impl = PostgresKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_MONGODB.value:
            from .mongodb import MongoDBKVStoreImpl

            impl = MongoDBKVStoreImpl(config)
        else:
            raise ValueError(f"Unknown kvstore type {config.type}")

        # Publish only after initialization succeeds, so a failed initialize()
        # never leaves a half-built store in the cache.
        await impl.initialize()
        _KVSTORE_INSTANCES[cache_key] = impl
        return impl
|
||||
|
|
|
|||
|
|
@ -252,19 +252,12 @@ class ResponsesStore:
|
|||
# Serialize messages to dict format for JSON storage
|
||||
messages_data = [msg.model_dump() for msg in messages]
|
||||
|
||||
# Upsert: try insert first, update if exists
|
||||
try:
|
||||
await self.sql_store.insert(
|
||||
table="conversation_messages",
|
||||
data={"conversation_id": conversation_id, "messages": messages_data},
|
||||
)
|
||||
except Exception:
|
||||
# If insert fails due to ID conflict, update existing record
|
||||
await self.sql_store.update(
|
||||
table="conversation_messages",
|
||||
data={"messages": messages_data},
|
||||
where={"conversation_id": conversation_id},
|
||||
)
|
||||
await self.sql_store.upsert(
|
||||
table="conversation_messages",
|
||||
data={"conversation_id": conversation_id, "messages": messages_data},
|
||||
conflict_columns=["conversation_id"],
|
||||
update_columns=["messages"],
|
||||
)
|
||||
|
||||
logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,18 @@ class SqlStore(Protocol):
|
|||
"""
|
||||
pass
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
conflict_columns: list[str],
|
||||
update_columns: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Insert a row and update specified columns when conflicts occur.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
|
|
|
|||
|
|
@ -129,6 +129,23 @@ class AuthorizedSqlStore:
|
|||
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
|
||||
await self.sql_store.insert(table, enhanced_data)
|
||||
|
||||
async def upsert(
    self,
    table: str,
    data: Mapping[str, Any],
    conflict_columns: list[str],
    update_columns: list[str] | None = None,
) -> None:
    """Upsert a row with automatic access control attribute capture.

    The row is passed through ``_enhance_item_with_access_control`` together
    with the currently authenticated user, and the result is forwarded
    unchanged to the wrapped store's ``upsert``.

    :param table: Name of the table to write to.
    :param data: Column name to value mapping for the row.
    :param conflict_columns: Columns that identify an existing row.
    :param update_columns: Columns to overwrite on conflict; passed through as given.
    """
    # NOTE(review): no access check is made against an already-existing row here.
    # If the wrapped store treats update_columns=None as "all non-conflict columns
    # in data", the access-control fields added below would presumably be
    # overwritten on conflict as well - confirm this is intended.
    row = _enhance_item_with_access_control(data, get_authenticated_user())
    await self.sql_store.upsert(
        table=table,
        data=row,
        conflict_columns=conflict_columns,
        update_columns=update_columns,
    )
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
|
|
|
|||
|
|
@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
|
|||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||
self.config = config
|
||||
self._is_sqlite_backend = "sqlite" in self.config.engine_str
|
||||
self.async_session = async_sessionmaker(self.create_engine())
|
||||
self.metadata = MetaData()
|
||||
|
||||
def create_engine(self) -> AsyncEngine:
|
||||
# Configure connection args for better concurrency support
|
||||
connect_args = {}
|
||||
if "sqlite" in self.config.engine_str:
|
||||
if self._is_sqlite_backend:
|
||||
# SQLite-specific optimizations for concurrent access
|
||||
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
|
||||
connect_args["timeout"] = 5.0
|
||||
|
|
@ -91,7 +92,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
)
|
||||
|
||||
# Enable WAL mode for SQLite to support concurrent readers and writers
|
||||
if "sqlite" in self.config.engine_str:
|
||||
if self._is_sqlite_backend:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_conn, connection_record):
|
||||
|
|
@ -151,6 +152,29 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
await session.execute(self.metadata.tables[table].insert(), data)
|
||||
await session.commit()
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
conflict_columns: list[str],
|
||||
update_columns: list[str] | None = None,
|
||||
) -> None:
|
||||
table_obj = self.metadata.tables[table]
|
||||
dialect_insert = self._get_dialect_insert(table_obj)
|
||||
insert_stmt = dialect_insert.values(**data)
|
||||
|
||||
if update_columns is None:
|
||||
update_columns = [col for col in data.keys() if col not in conflict_columns]
|
||||
|
||||
update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
|
||||
conflict_cols = [table_obj.c[col] for col in conflict_columns]
|
||||
|
||||
stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
|
||||
|
||||
async with self.async_session() as session:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
|
|
@ -333,9 +357,18 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
|
||||
|
||||
await conn.execute(add_column_sql)
|
||||
|
||||
except Exception as e:
|
||||
# If any error occurs during migration, log it but don't fail
|
||||
# The table creation will handle adding the column
|
||||
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||
pass
|
||||
|
||||
def _get_dialect_insert(self, table: Table):
    """Return the dialect-specific ``insert()`` construct for ``table``.

    The SQLite dialect's ``insert`` is used when ``self._is_sqlite_backend`` is
    true; every other backend gets the PostgreSQL dialect's ``insert``.

    :param table: The SQLAlchemy table to build the INSERT for.
    """
    if not self._is_sqlite_backend:
        from sqlalchemy.dialects.postgresql import insert as dialect_insert
    else:
        from sqlalchemy.dialects.sqlite import insert as dialect_insert

    return dialect_insert(table)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from threading import Lock
|
||||
from typing import Annotated, cast
|
||||
|
||||
from pydantic import Field
|
||||
|
|
@ -21,6 +22,8 @@ from .api import SqlStore
|
|||
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
||||
|
||||
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
|
||||
_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
|
||||
_SQLSTORE_LOCKS: dict[str, Lock] = {}
|
||||
|
||||
|
||||
SqlStoreConfig = Annotated[
|
||||
|
|
@ -52,19 +55,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
|
|||
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
|
||||
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
existing = _SQLSTORE_INSTANCES.get(backend_name)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
|
||||
return SqlAlchemySqlStoreImpl(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
|
||||
lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
|
||||
with lock:
|
||||
existing = _SQLSTORE_INSTANCES.get(backend_name)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
|
||||
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
|
||||
instance = SqlAlchemySqlStoreImpl(config)
|
||||
_SQLSTORE_INSTANCES[backend_name] = instance
|
||||
return instance
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
|
||||
|
||||
|
||||
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Replace the set of SQL store backends available for reference resolution.

    The backend registry, the cached store instances and the per-backend locks
    are all emptied in place before ``backends`` is recorded.

    Args:
        backends: Mapping of backend name to its configuration.
    """
    for registry in (_SQLSTORE_BACKENDS, _SQLSTORE_INSTANCES, _SQLSTORE_LOCKS):
        registry.clear()
    _SQLSTORE_BACKENDS.update(backends)
|
||||
|
|
|
|||
128
tests/integration/responses/test_conversation_stress.py
Normal file
128
tests/integration/responses/test_conversation_stress.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
"""
|
||||
Opt-in concurrency stress tests for the Responses↔Conversations sync path.
|
||||
|
||||
By default this module is skipped to avoid lengthening CI. Set the environment
|
||||
variable LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 to run it locally. Optional
|
||||
knobs:
|
||||
|
||||
LLAMA_STACK_CONCURRENCY_WORKERS (default: 4)
|
||||
LLAMA_STACK_CONCURRENCY_TURNS (default: 1)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.testing_context import reset_test_context, set_test_context
|
||||
|
||||
# Opt-in guard: skip the whole module unless the enabling variable is set to a
# non-empty value.
if not os.environ.get("LLAMA_STACK_ENABLE_CONCURRENCY_TESTS"):
    pytest.skip("Set LLAMA_STACK_ENABLE_CONCURRENCY_TESTS=1 to run stress tests", allow_module_level=True)

# Number of worker threads, and how many conversations each worker runs.
CONCURRENCY_WORKERS = int(os.environ.get("LLAMA_STACK_CONCURRENCY_WORKERS", "4"))
TURNS_PER_WORKER = int(os.environ.get("LLAMA_STACK_CONCURRENCY_TURNS", "1"))

# Reuse the exact prompts from test_conversation_responses to hit existing recordings.
FIRST_TURN_PROMPT = "Say hello"
STREAMING_PROMPT = "Say goodbye"
# Test node id that workers install as their test context.
RECORDED_TEST_NODE_ID = (
    "tests/integration/responses/test_conversation_responses.py::TestConversationResponses::"
    + "test_conversation_multi_turn_and_streaming[txt=openai/gpt-4o]"
)
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_conversation_streaming_multi_client(openai_client, text_model_id):
    """Run many copies of the conversation streaming flow concurrently (matching recorded prompts).

    Each worker thread creates ``TURNS_PER_WORKER`` conversations. For every
    conversation it sends one non-streaming turn followed by one streaming
    turn. After all workers finish, every conversation must contain at least
    four items.
    """
    if CONCURRENCY_WORKERS < 2:
        pytest.skip("Need at least 2 workers for concurrency stress")

    def read_field(obj, name):
        """Read ``name`` as an attribute, falling back to a dict lookup when ``obj`` is a dict."""
        value = getattr(obj, name, None)
        if value is None and isinstance(obj, dict):
            value = obj.get(name)
        return value

    def describe_failure(failed_response):
        """Extract the error message carried by a ``response.failed`` event."""
        error_obj = read_field(failed_response, "error")
        if isinstance(error_obj, dict):
            return error_obj.get("message", "response.failed")
        if hasattr(error_obj, "message"):
            return error_obj.message
        return "response.failed event without error"

    def stream_single_turn(conversation_id: str, worker_idx: int, turn_idx: int):
        """Send the streaming turn and ensure completion."""
        events = openai_client.responses.create(
            model=text_model_id,
            input=[{"role": "user", "content": STREAMING_PROMPT}],
            conversation=conversation_id,
            stream=True,
        )

        successful_end = {"response.completed", "response.incomplete"}
        final_response = None
        failed_error: str | None = None

        # Stop consuming at the first terminal event of either kind.
        for event in events:
            event_type = read_field(event, "type")
            event_response = read_field(event, "response")

            if event_type in successful_end:
                final_response = event_response
                break
            if event_type == "response.failed":
                failed_error = describe_failure(event_response)
                break

        if failed_error:
            raise RuntimeError(f"Worker {worker_idx} turn {turn_idx} failed: {failed_error}")
        if final_response is None:
            raise RuntimeError(f"Worker {worker_idx} turn {turn_idx} did not complete")

    def worker(worker_idx: int) -> list[str]:
        """Run this worker's conversations and return the ids it created."""
        # Install the recorded test's node id as this thread's test context;
        # it is always reset on the way out.
        token = set_test_context(RECORDED_TEST_NODE_ID)
        created_ids: list[str] = []
        try:
            for turn_idx in range(TURNS_PER_WORKER):
                conversation = openai_client.conversations.create()
                assert conversation.id.startswith("conv_")
                created_ids.append(conversation.id)

                openai_client.responses.create(
                    model=text_model_id,
                    input=[{"role": "user", "content": FIRST_TURN_PROMPT}],
                    conversation=conversation.id,
                )

                stream_single_turn(conversation.id, worker_idx, turn_idx)

            return created_ids
        finally:
            reset_test_context(token)

    with ThreadPoolExecutor(max_workers=CONCURRENCY_WORKERS) as executor:
        futures = [executor.submit(worker, worker_idx) for worker_idx in range(CONCURRENCY_WORKERS)]
        # result() re-raises any exception a worker hit, failing the test.
        all_conversation_ids: list[str] = [
            conversation_id for future in as_completed(futures) for conversation_id in future.result()
        ]

    expected_conversations = CONCURRENCY_WORKERS * TURNS_PER_WORKER
    assert len(all_conversation_ids) == expected_conversations

    for conversation_id in all_conversation_ids:
        conversation_items = openai_client.conversations.items.list(conversation_id)
        assert len(conversation_items.data) >= 4, (
            f"Conversation {conversation_id} missing expected items: {len(conversation_items.data)}"
        )
|
||||
|
|
@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
|
@ -65,6 +65,38 @@ async def test_sqlite_sqlstore():
|
|||
assert result.has_more is False
|
||||
|
||||
|
||||
async def test_sqlstore_upsert_support():
    """Exercise upsert: initial insert, full update on conflict, and partial update on conflict."""
    with TemporaryDirectory() as tmp_dir:
        db_path = tmp_dir + "/upsert.db"
        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))

        await store.create_table(
            "items",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "value": ColumnType.STRING,
                "updated_at": ColumnType.INTEGER,
            },
        )

        # No existing row: upsert behaves like a plain insert.
        await store.upsert(
            table="items",
            data={"id": "item_1", "value": "first", "updated_at": 1},
            conflict_columns=["id"],
        )
        row = await store.fetch_one("items", {"id": "item_1"})
        assert row == {"id": "item_1", "value": "first", "updated_at": 1}

        # Conflict on the primary key: the listed columns are overwritten.
        await store.upsert(
            table="items",
            data={"id": "item_1", "value": "second", "updated_at": 2},
            conflict_columns=["id"],
            update_columns=["value", "updated_at"],
        )
        row = await store.fetch_one("items", {"id": "item_1"})
        assert row == {"id": "item_1", "value": "second", "updated_at": 2}

        # Conflict with a restricted update list: only "value" may change, so
        # "updated_at" must keep its stored value. Without this case the test
        # would pass even if update_columns were ignored.
        await store.upsert(
            table="items",
            data={"id": "item_1", "value": "third", "updated_at": 3},
            conflict_columns=["id"],
            update_columns=["value"],
        )
        row = await store.fetch_one("items", {"id": "item_1"})
        assert row == {"id": "item_1", "value": "third", "updated_at": 2}
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_basic():
|
||||
"""Test basic pagination functionality at the SQL store level."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue