From b95f095a54ce1928ba9ead4abe3eceb99ffa4f5d Mon Sep 17 00:00:00 2001 From: raghotham Date: Mon, 13 Oct 2025 02:19:27 -0700 Subject: [PATCH] feat: Allow :memory: for kvstore (#3696) ## Test Plan added unit tests --- .../providers/utils/kvstore/sqlite/sqlite.py | 119 +++++++++++++++--- .../unit/utils/kvstore/test_sqlite_memory.py | 30 +++++ 2 files changed, 133 insertions(+), 16 deletions(-) create mode 100644 tests/unit/utils/kvstore/test_sqlite_memory.py diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 5b782902e..a9a7a1304 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -21,14 +21,27 @@ class SqliteKVStoreImpl(KVStore): def __init__(self, config: SqliteKVStoreConfig): self.db_path = config.db_path self.table_name = "kvstore" + self._conn: aiosqlite.Connection | None = None def __str__(self): return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})" + def _is_memory_db(self) -> bool: + """Check if this is an in-memory database.""" + return self.db_path == ":memory:" or "mode=memory" in self.db_path + async def initialize(self): - os.makedirs(os.path.dirname(self.db_path), exist_ok=True) - async with aiosqlite.connect(self.db_path) as db: - await db.execute( + # Skip directory creation for in-memory databases and file: URIs + if not self._is_memory_db() and not self.db_path.startswith("file:"): + db_dir = os.path.dirname(self.db_path) + if db_dir: # Only create if there's a directory component + os.makedirs(db_dir, exist_ok=True) + + # Only use persistent connection for in-memory databases + # File-based databases use connection-per-operation to avoid hangs + if self._is_memory_db(): + self._conn = await aiosqlite.connect(self.db_path) + await self._conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( key TEXT PRIMARY KEY, @@ -37,19 +50,50 @@ class SqliteKVStoreImpl(KVStore): ) """ ) - await db.commit() + await self._conn.commit() + else: + # For file-based databases, just create the table + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP + ) + """ + ) + await db.commit() + + async def shutdown(self): + """Close the persistent connection (only for in-memory databases).""" + if self._conn: + await self._conn.close() + self._conn = None async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: - async with aiosqlite.connect(self.db_path) as db: - await db.execute( + if self._conn: + # In-memory database with persistent connection + await self._conn.execute( f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", (key, value, expiration), ) - await db.commit() + await self._conn.commit() + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", + (key, value, expiration), + ) + await db.commit() async def get(self, key: str) -> str | None: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor: + if self._conn: + # In-memory database with persistent connection + async with self._conn.execute( + f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) + ) as cursor: row = await cursor.fetchone() if row is None: return None @@ -58,15 +102,36 @@ class SqliteKVStoreImpl(KVStore): logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") return None return value + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) + ) as cursor: + row = await cursor.fetchone() + if row is None: + return None + value, expiration = row + if not isinstance(value, str): + logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") + return None + return value async def delete(self, key: str) -> None: - async with aiosqlite.connect(self.db_path) as db: - await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) - await db.commit() + if self._conn: + # In-memory database with persistent connection + await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await self._conn.commit() + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await db.commit() async def values_in_range(self, start_key: str, end_key: str) -> list[str]: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( + if self._conn: + # In-memory database with persistent connection + async with self._conn.execute( f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", (start_key, end_key), ) as cursor: @@ -75,13 +140,35 @@ class SqliteKVStoreImpl(KVStore): _, value, _ = row result.append(value) return result + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) as cursor: + result = [] + async for row in cursor: + _, value, _ = row + result.append(value) + return result async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: """Get all keys in the given range.""" - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute( + if self._conn: + # In-memory database with persistent connection + cursor = await self._conn.execute( f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", (start_key, end_key), ) rows = await cursor.fetchall() return [row[0] for row in rows] + else: + # File-based database with connection per operation + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) + rows = await cursor.fetchall() + return [row[0] for row in rows] diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py new file mode 100644 index 000000000..a31377306 --- /dev/null +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl + + +async def test_memory_kvstore_persistence_behavior(): + """Test that :memory: database doesn't persist across instances.""" + config = SqliteKVStoreConfig(db_path=":memory:") + + # First instance + store1 = SqliteKVStoreImpl(config) + await store1.initialize() + await store1.set("persist_test", "should_not_persist") + await store1.shutdown() + + # Second instance with same config + store2 = SqliteKVStoreImpl(config) + await store2.initialize() + + # Data should not be present + result = await store2.get("persist_test") + assert result is None + + await store2.shutdown()