feat(storage): share sql/kv instances and add upsert support (#4140)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test Llama Stack Build / generate-matrix (push) Successful in 2s
Python Package Build Test / build (3.12) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Python Package Build Test / build (3.13) (push) Failing after 17s
Test Llama Stack Build / build-single-provider (push) Successful in 31s
Test External API and Providers / test-external (venv) (push) Failing after 32s
Vector IO Integration Tests / test-matrix (push) Failing after 45s
Test Llama Stack Build / build (push) Successful in 47s
UI Tests / ui-tests (22) (push) Successful in 1m42s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 2m8s
Unit Tests / unit-tests (3.13) (push) Failing after 2m7s
Unit Tests / unit-tests (3.12) (push) Failing after 2m28s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 2m32s
Pre-commit / pre-commit (push) Successful in 3m20s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3m33s

A few changes to the storage layer to ensure we reduce unnecessary
contention arising out of our design choices (and letting the database
layer do its correct thing):

- SQL stores now share a single `SqlAlchemySqlStoreImpl` per backend,
and `kvstore_impl` caches instances per `(backend, namespace)`. This
avoids spawning multiple SQLite connections for the same file, reducing
lock contention and aligning the cache story for all backends.

- Added an async upsert API (with SQLite/Postgres dialect inserts) and
routed it through `AuthorizedSqlStore`, then switched conversations and
responses to call it. Using native `ON CONFLICT DO UPDATE` eliminates
the insert-then-update retry window that previously caused long WAL lock
retries.

### Test Plan

Existing tests, added a unit test for `upsert()`
This commit is contained in:
Ashwin Bharambe 2025-11-12 12:14:26 -08:00 committed by GitHub
parent 492f79ca9b
commit fcf649b97a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 172 additions and 51 deletions

View file

@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations):
"item_data": item_dict,
}
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
try:
await self.sql_store.insert(table="conversation_items", data=item_record)
except Exception:
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_items",
data={"created_at": created_at, "item_data": item_dict},
where={"id": item_id},
)
await self.sql_store.upsert(
table="conversation_items",
data=item_record,
conflict_columns=["id"],
)
created_items.append(item_dict)

View file

@ -11,6 +11,9 @@
from __future__ import annotations
import asyncio
from collections import defaultdict
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
from .api import KVStore
@ -53,45 +56,63 @@ class InmemoryKVStoreImpl(KVStore):
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available KV store backends for reference resolution."""
global _KVSTORE_BACKENDS
global _KVSTORE_INSTANCES
global _KVSTORE_LOCKS
_KVSTORE_BACKENDS.clear()
_KVSTORE_INSTANCES.clear()
_KVSTORE_LOCKS.clear()
for name, cfg in backends.items():
_KVSTORE_BACKENDS[name] = cfg
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
backend_name = reference.backend
cache_key = (backend_name, reference.namespace)
existing = _KVSTORE_INSTANCES.get(cache_key)
if existing:
return existing
backend_config = _KVSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")
config = backend_config.model_copy()
config.namespace = reference.namespace
lock = _KVSTORE_LOCKS[cache_key]
async with lock:
existing = _KVSTORE_INSTANCES.get(cache_key)
if existing:
return existing
if config.type == StorageBackendType.KV_REDIS.value:
from .redis import RedisKVStoreImpl
config = backend_config.model_copy()
config.namespace = reference.namespace
impl = RedisKVStoreImpl(config)
elif config.type == StorageBackendType.KV_SQLITE.value:
from .sqlite import SqliteKVStoreImpl
if config.type == StorageBackendType.KV_REDIS.value:
from .redis import RedisKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif config.type == StorageBackendType.KV_POSTGRES.value:
from .postgres import PostgresKVStoreImpl
impl = RedisKVStoreImpl(config)
elif config.type == StorageBackendType.KV_SQLITE.value:
from .sqlite import SqliteKVStoreImpl
impl = PostgresKVStoreImpl(config)
elif config.type == StorageBackendType.KV_MONGODB.value:
from .mongodb import MongoDBKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif config.type == StorageBackendType.KV_POSTGRES.value:
from .postgres import PostgresKVStoreImpl
impl = MongoDBKVStoreImpl(config)
else:
raise ValueError(f"Unknown kvstore type {config.type}")
impl = PostgresKVStoreImpl(config)
elif config.type == StorageBackendType.KV_MONGODB.value:
from .mongodb import MongoDBKVStoreImpl
await impl.initialize()
return impl
impl = MongoDBKVStoreImpl(config)
else:
raise ValueError(f"Unknown kvstore type {config.type}")
await impl.initialize()
_KVSTORE_INSTANCES[cache_key] = impl
return impl

View file

@ -252,19 +252,12 @@ class ResponsesStore:
# Serialize messages to dict format for JSON storage
messages_data = [msg.model_dump() for msg in messages]
# Upsert: try insert first, update if exists
try:
await self.sql_store.insert(
table="conversation_messages",
data={"conversation_id": conversation_id, "messages": messages_data},
)
except Exception:
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_messages",
data={"messages": messages_data},
where={"conversation_id": conversation_id},
)
await self.sql_store.upsert(
table="conversation_messages",
data={"conversation_id": conversation_id, "messages": messages_data},
conflict_columns=["conversation_id"],
update_columns=["messages"],
)
logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")

View file

@ -47,6 +47,18 @@ class SqlStore(Protocol):
"""
pass
async def upsert(
self,
table: str,
data: Mapping[str, Any],
conflict_columns: list[str],
update_columns: list[str] | None = None,
) -> None:
"""
Insert a row and update specified columns when conflicts occur.
"""
pass
async def fetch_all(
self,
table: str,

View file

@ -129,6 +129,23 @@ class AuthorizedSqlStore:
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
await self.sql_store.insert(table, enhanced_data)
async def upsert(
self,
table: str,
data: Mapping[str, Any],
conflict_columns: list[str],
update_columns: list[str] | None = None,
) -> None:
"""Upsert a row with automatic access control attribute capture."""
current_user = get_authenticated_user()
enhanced_data = _enhance_item_with_access_control(data, current_user)
await self.sql_store.upsert(
table=table,
data=enhanced_data,
conflict_columns=conflict_columns,
update_columns=update_columns,
)
async def fetch_all(
self,
table: str,

View file

@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
class SqlAlchemySqlStoreImpl(SqlStore):
def __init__(self, config: SqlAlchemySqlStoreConfig):
self.config = config
self._is_sqlite_backend = "sqlite" in self.config.engine_str
self.async_session = async_sessionmaker(self.create_engine())
self.metadata = MetaData()
def create_engine(self) -> AsyncEngine:
# Configure connection args for better concurrency support
connect_args = {}
if "sqlite" in self.config.engine_str:
if self._is_sqlite_backend:
# SQLite-specific optimizations for concurrent access
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
connect_args["timeout"] = 5.0
@ -91,7 +92,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
)
# Enable WAL mode for SQLite to support concurrent readers and writers
if "sqlite" in self.config.engine_str:
if self._is_sqlite_backend:
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_conn, connection_record):
@ -151,6 +152,29 @@ class SqlAlchemySqlStoreImpl(SqlStore):
await session.execute(self.metadata.tables[table].insert(), data)
await session.commit()
async def upsert(
self,
table: str,
data: Mapping[str, Any],
conflict_columns: list[str],
update_columns: list[str] | None = None,
) -> None:
table_obj = self.metadata.tables[table]
dialect_insert = self._get_dialect_insert(table_obj)
insert_stmt = dialect_insert.values(**data)
if update_columns is None:
update_columns = [col for col in data.keys() if col not in conflict_columns]
update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
conflict_cols = [table_obj.c[col] for col in conflict_columns]
stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
async with self.async_session() as session:
await session.execute(stmt)
await session.commit()
async def fetch_all(
self,
table: str,
@ -333,9 +357,18 @@ class SqlAlchemySqlStoreImpl(SqlStore):
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
await conn.execute(add_column_sql)
except Exception as e:
# If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass
def _get_dialect_insert(self, table: Table):
if self._is_sqlite_backend:
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
return sqlite_insert(table)
else:
from sqlalchemy.dialects.postgresql import insert as pg_insert
return pg_insert(table)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from threading import Lock
from typing import Annotated, cast
from pydantic import Field
@ -21,6 +22,8 @@ from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
_SQLSTORE_LOCKS: dict[str, Lock] = {}
SqlStoreConfig = Annotated[
@ -52,19 +55,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
existing = _SQLSTORE_INSTANCES.get(backend_name)
if existing:
return existing
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
return SqlAlchemySqlStoreImpl(config)
else:
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
with lock:
existing = _SQLSTORE_INSTANCES.get(backend_name)
if existing:
return existing
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
instance = SqlAlchemySqlStoreImpl(config)
_SQLSTORE_INSTANCES[backend_name] = instance
return instance
else:
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
"""Register the set of available SQL store backends for reference resolution."""
global _SQLSTORE_BACKENDS
global _SQLSTORE_INSTANCES
_SQLSTORE_BACKENDS.clear()
_SQLSTORE_INSTANCES.clear()
_SQLSTORE_LOCKS.clear()
for name, cfg in backends.items():
_SQLSTORE_BACKENDS[name] = cfg

View file

@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory
import pytest
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@ -65,6 +65,38 @@ async def test_sqlite_sqlstore():
assert result.has_more is False
async def test_sqlstore_upsert_support():
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/upsert.db"
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
await store.create_table(
"items",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"value": ColumnType.STRING,
"updated_at": ColumnType.INTEGER,
},
)
await store.upsert(
table="items",
data={"id": "item_1", "value": "first", "updated_at": 1},
conflict_columns=["id"],
)
row = await store.fetch_one("items", {"id": "item_1"})
assert row == {"id": "item_1", "value": "first", "updated_at": 1}
await store.upsert(
table="items",
data={"id": "item_1", "value": "second", "updated_at": 2},
conflict_columns=["id"],
update_columns=["value", "updated_at"],
)
row = await store.fetch_one("items", {"id": "item_1"})
assert row == {"id": "item_1", "value": "second", "updated_at": 2}
async def test_sqlstore_pagination_basic():
"""Test basic pagination functionality at the SQL store level."""
with TemporaryDirectory() as tmp_dir: