mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(storage): share sql/kv instances and add upsert support (#4140)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test Llama Stack Build / generate-matrix (push) Successful in 2s
Python Package Build Test / build (3.12) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Python Package Build Test / build (3.13) (push) Failing after 17s
Test Llama Stack Build / build-single-provider (push) Successful in 31s
Test External API and Providers / test-external (venv) (push) Failing after 32s
Vector IO Integration Tests / test-matrix (push) Failing after 45s
Test Llama Stack Build / build (push) Successful in 47s
UI Tests / ui-tests (22) (push) Successful in 1m42s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 2m8s
Unit Tests / unit-tests (3.13) (push) Failing after 2m7s
Unit Tests / unit-tests (3.12) (push) Failing after 2m28s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 2m32s
Pre-commit / pre-commit (push) Successful in 3m20s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3m33s
A few changes to the storage layer to reduce unnecessary contention arising from our design choices (and to let the database layer do its job):

- SQL stores now share a single `SqlAlchemySqlStoreImpl` per backend, and `kvstore_impl` caches instances per `(backend, namespace)`. This avoids spawning multiple SQLite connections for the same file, reducing lock contention and aligning the cache story for all backends.
- Added an async upsert API (with SQLite/Postgres dialect inserts) and routed it through `AuthorizedSqlStore`, then switched conversations and responses to call it. Using native `ON CONFLICT DO UPDATE` eliminates the insert-then-update retry window that previously caused long WAL lock retries.

### Test Plan

Existing tests, plus a new unit test for `upsert()`.
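For illustration (not part of this diff), a minimal standalone sketch of what the dialect-specific insert compiles to; the table and column names below are made up:

```python
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

metadata = sa.MetaData()
# Illustrative table; the real tables come from the store's own metadata.
items = sa.Table(
    "items",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("value", sa.String),
)

stmt = sqlite_insert(items).values(id="item_1", value="second")
stmt = stmt.on_conflict_do_update(
    index_elements=[items.c.id],          # conflict target: the primary key
    set_={"value": stmt.excluded.value},  # take the incoming value on conflict
)
print(stmt.compile(dialect=sqlite.dialect()))
# Renders roughly:
#   INSERT INTO items (id, value) VALUES (?, ?)
#   ON CONFLICT (id) DO UPDATE SET value = excluded.value
```

Because the insert and the conflict resolution are a single statement, the database resolves the race itself instead of the application retrying a failed insert.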
parent 492f79ca9b
commit fcf649b97a
8 changed files with 172 additions and 51 deletions
@@ -203,15 +203,10 @@ class ConversationServiceImpl(Conversations):
             "item_data": item_dict,
         }

-        # TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
-        try:
-            await self.sql_store.insert(table="conversation_items", data=item_record)
-        except Exception:
-            # If insert fails due to ID conflict, update existing record
-            await self.sql_store.update(
-                table="conversation_items",
-                data={"created_at": created_at, "item_data": item_dict},
-                where={"id": item_id},
-            )
+        await self.sql_store.upsert(
+            table="conversation_items",
+            data=item_record,
+            conflict_columns=["id"],
+        )

         created_items.append(item_dict)
@@ -11,6 +11,9 @@
 from __future__ import annotations

+import asyncio
+from collections import defaultdict
+
 from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType

 from .api import KVStore
@@ -53,24 +56,41 @@ class InmemoryKVStoreImpl(KVStore):

 _KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
+_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
+_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)


 def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
     """Register the set of available KV store backends for reference resolution."""
     global _KVSTORE_BACKENDS
+    global _KVSTORE_INSTANCES
+    global _KVSTORE_LOCKS

     _KVSTORE_BACKENDS.clear()
+    _KVSTORE_INSTANCES.clear()
+    _KVSTORE_LOCKS.clear()
     for name, cfg in backends.items():
         _KVSTORE_BACKENDS[name] = cfg


 async def kvstore_impl(reference: KVStoreReference) -> KVStore:
     backend_name = reference.backend
+    cache_key = (backend_name, reference.namespace)
+
+    existing = _KVSTORE_INSTANCES.get(cache_key)
+    if existing:
+        return existing
+
     backend_config = _KVSTORE_BACKENDS.get(backend_name)
     if backend_config is None:
         raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")

+    lock = _KVSTORE_LOCKS[cache_key]
+    async with lock:
+        existing = _KVSTORE_INSTANCES.get(cache_key)
+        if existing:
+            return existing
+
         config = backend_config.model_copy()
         config.namespace = reference.namespace
@@ -94,4 +114,5 @@ async def kvstore_impl(reference: KVStoreReference) -> KVStore:
             raise ValueError(f"Unknown kvstore type {config.type}")

         await impl.initialize()
+        _KVSTORE_INSTANCES[cache_key] = impl
         return impl
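Aside (not part of the diff): the kvstore hunks above use a double-checked cache around a per-key `asyncio.Lock`. A minimal self-contained sketch of that pattern, with made-up names, showing why concurrent callers end up sharing one instance:

```python
import asyncio
from collections import defaultdict

_instances: dict[str, object] = {}
_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)


async def _expensive_init(key: str) -> object:
    await asyncio.sleep(0)  # stand-in for opening a DB connection
    return object()


async def get_or_create(key: str) -> object:
    existing = _instances.get(key)
    if existing is not None:
        return existing  # fast path: already cached, no lock taken
    async with _locks[key]:
        existing = _instances.get(key)  # re-check after acquiring the lock
        if existing is not None:
            return existing  # another task created it while we waited
        instance = await _expensive_init(key)
        _instances[key] = instance
        return instance


async def main() -> None:
    a, b = await asyncio.gather(get_or_create("k"), get_or_create("k"))
    assert a is b  # concurrent callers share one instance


asyncio.run(main())
```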
@@ -252,18 +252,11 @@ class ResponsesStore:
         # Serialize messages to dict format for JSON storage
         messages_data = [msg.model_dump() for msg in messages]

-        # Upsert: try insert first, update if exists
-        try:
-            await self.sql_store.insert(
-                table="conversation_messages",
-                data={"conversation_id": conversation_id, "messages": messages_data},
-            )
-        except Exception:
-            # If insert fails due to ID conflict, update existing record
-            await self.sql_store.update(
-                table="conversation_messages",
-                data={"messages": messages_data},
-                where={"conversation_id": conversation_id},
-            )
+        await self.sql_store.upsert(
+            table="conversation_messages",
+            data={"conversation_id": conversation_id, "messages": messages_data},
+            conflict_columns=["conversation_id"],
+            update_columns=["messages"],
+        )

         logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")
@@ -47,6 +47,18 @@ class SqlStore(Protocol):
         """
         pass

+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        """
+        Insert a row and update specified columns when conflicts occur.
+        """
+        pass
+
     async def fetch_all(
         self,
         table: str,
@@ -129,6 +129,23 @@ class AuthorizedSqlStore:
         enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
         await self.sql_store.insert(table, enhanced_data)

+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        """Upsert a row with automatic access control attribute capture."""
+        current_user = get_authenticated_user()
+        enhanced_data = _enhance_item_with_access_control(data, current_user)
+        await self.sql_store.upsert(
+            table=table,
+            data=enhanced_data,
+            conflict_columns=conflict_columns,
+            update_columns=update_columns,
+        )
+
     async def fetch_all(
         self,
         table: str,
@@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
 class SqlAlchemySqlStoreImpl(SqlStore):
     def __init__(self, config: SqlAlchemySqlStoreConfig):
         self.config = config
+        self._is_sqlite_backend = "sqlite" in self.config.engine_str
         self.async_session = async_sessionmaker(self.create_engine())
         self.metadata = MetaData()

     def create_engine(self) -> AsyncEngine:
         # Configure connection args for better concurrency support
         connect_args = {}
-        if "sqlite" in self.config.engine_str:
+        if self._is_sqlite_backend:
             # SQLite-specific optimizations for concurrent access
             # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
             connect_args["timeout"] = 5.0
@@ -91,7 +92,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         )

         # Enable WAL mode for SQLite to support concurrent readers and writers
-        if "sqlite" in self.config.engine_str:
+        if self._is_sqlite_backend:

             @event.listens_for(engine.sync_engine, "connect")
             def set_sqlite_pragma(dbapi_conn, connection_record):
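Aside (not part of the diff): `set_sqlite_pragma` above follows SQLAlchemy's standard connect-event recipe. A minimal sketch of that setup; the engine URL is made up, and the pragma body is an assumption based on the "Enable WAL mode" comment (the real listener may do more):

```python
from sqlalchemy import event
from sqlalchemy.ext.asyncio import create_async_engine

# Illustrative URL; the real engine string comes from config.engine_str.
engine = create_async_engine(
    "sqlite+aiosqlite:///example.db",
    connect_args={"timeout": 5.0},  # allow brief waits on a locked database
)


@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_conn, connection_record):
    # Assumed body: WAL lets concurrent readers proceed alongside one writer.
    cursor = dbapi_conn.cursor()
    cursor.execute("PRAGMA journal_mode=WAL")
    cursor.close()
```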
@@ -151,6 +152,29 @@ class SqlAlchemySqlStoreImpl(SqlStore):
             await session.execute(self.metadata.tables[table].insert(), data)
             await session.commit()

+    async def upsert(
+        self,
+        table: str,
+        data: Mapping[str, Any],
+        conflict_columns: list[str],
+        update_columns: list[str] | None = None,
+    ) -> None:
+        table_obj = self.metadata.tables[table]
+        dialect_insert = self._get_dialect_insert(table_obj)
+        insert_stmt = dialect_insert.values(**data)
+
+        if update_columns is None:
+            update_columns = [col for col in data.keys() if col not in conflict_columns]
+
+        update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
+        conflict_cols = [table_obj.c[col] for col in conflict_columns]
+
+        stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
+
+        async with self.async_session() as session:
+            await session.execute(stmt)
+            await session.commit()
+
     async def fetch_all(
         self,
         table: str,
@@ -333,9 +357,18 @@ class SqlAlchemySqlStoreImpl(SqlStore):
                     add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")

                     await conn.execute(add_column_sql)

             except Exception as e:
                 # If any error occurs during migration, log it but don't fail
                 # The table creation will handle adding the column
                 logger.error(f"Error adding column {column_name} to table {table}: {e}")
                 pass

+    def _get_dialect_insert(self, table: Table):
+        if self._is_sqlite_backend:
+            from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+            return sqlite_insert(table)
+        else:
+            from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+            return pg_insert(table)
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from threading import Lock
 from typing import Annotated, cast

 from pydantic import Field
@@ -21,6 +22,8 @@ from .api import SqlStore
 sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]

 _SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
+_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
+_SQLSTORE_LOCKS: dict[str, Lock] = {}


 SqlStoreConfig = Annotated[
@@ -52,11 +55,23 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
             f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
         )

+    existing = _SQLSTORE_INSTANCES.get(backend_name)
+    if existing:
+        return existing
+
+    lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
+    with lock:
+        existing = _SQLSTORE_INSTANCES.get(backend_name)
+        if existing:
+            return existing
+
         if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
             from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl

             config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
-            return SqlAlchemySqlStoreImpl(config)
+            instance = SqlAlchemySqlStoreImpl(config)
+            _SQLSTORE_INSTANCES[backend_name] = instance
+            return instance
         else:
             raise ValueError(f"Unknown sqlstore type {backend_config.type}")
@@ -64,7 +79,10 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
 def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
     """Register the set of available SQL store backends for reference resolution."""
     global _SQLSTORE_BACKENDS
+    global _SQLSTORE_INSTANCES

     _SQLSTORE_BACKENDS.clear()
+    _SQLSTORE_INSTANCES.clear()
+    _SQLSTORE_LOCKS.clear()
     for name, cfg in backends.items():
         _SQLSTORE_BACKENDS[name] = cfg
@@ -9,7 +9,7 @@ from tempfile import TemporaryDirectory

 import pytest

-from llama_stack.providers.utils.sqlstore.api import ColumnType
+from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@@ -65,6 +65,38 @@ async def test_sqlite_sqlstore():
     assert result.has_more is False


+async def test_sqlstore_upsert_support():
+    with TemporaryDirectory() as tmp_dir:
+        db_path = tmp_dir + "/upsert.db"
+        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
+
+        await store.create_table(
+            "items",
+            {
+                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
+                "value": ColumnType.STRING,
+                "updated_at": ColumnType.INTEGER,
+            },
+        )
+
+        await store.upsert(
+            table="items",
+            data={"id": "item_1", "value": "first", "updated_at": 1},
+            conflict_columns=["id"],
+        )
+        row = await store.fetch_one("items", {"id": "item_1"})
+        assert row == {"id": "item_1", "value": "first", "updated_at": 1}
+
+        await store.upsert(
+            table="items",
+            data={"id": "item_1", "value": "second", "updated_at": 2},
+            conflict_columns=["id"],
+            update_columns=["value", "updated_at"],
+        )
+        row = await store.fetch_one("items", {"id": "item_1"})
+        assert row == {"id": "item_1", "value": "second", "updated_at": 2}
+
+
 async def test_sqlstore_pagination_basic():
     """Test basic pagination functionality at the SQL store level."""
     with TemporaryDirectory() as tmp_dir: