From 106c9b0caa0613060eca35c032fe9aad83c357c6 Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Sat, 4 Oct 2025 13:05:51 -0700 Subject: [PATCH] feat: Allow :memory: for kvstore --- .../providers/utils/kvstore/sqlite/sqlite.py | 113 ++++++++++------- .../unit/utils/kvstore/test_sqlite_memory.py | 116 ++++++++++++++++++ 2 files changed, 182 insertions(+), 47 deletions(-) create mode 100644 tests/unit/utils/kvstore/test_sqlite_memory.py diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 5b782902e..e1c0332fe 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -21,67 +21,86 @@ class SqliteKVStoreImpl(KVStore): def __init__(self, config: SqliteKVStoreConfig): self.db_path = config.db_path self.table_name = "kvstore" + self._conn: aiosqlite.Connection | None = None def __str__(self): return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})" + def _is_memory_db(self) -> bool: + """Check if this is an in-memory database.""" + return self.db_path == ":memory:" or "mode=memory" in self.db_path + async def initialize(self): - os.makedirs(os.path.dirname(self.db_path), exist_ok=True) - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - f""" - CREATE TABLE IF NOT EXISTS {self.table_name} ( - key TEXT PRIMARY KEY, - value TEXT, - expiration TIMESTAMP - ) - """ + # Skip directory creation for in-memory databases and file: URIs + if not self._is_memory_db() and not self.db_path.startswith("file:"): + db_dir = os.path.dirname(self.db_path) + if db_dir: # Only create if there's a directory component + os.makedirs(db_dir, exist_ok=True) + + # Create persistent connection for all databases + self._conn = await aiosqlite.connect(self.db_path) + await self._conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP ) - await db.commit() + """ + ) + await self._conn.commit() + + async def close(self): + """Close the persistent connection.""" + if self._conn: + await self._conn.close() + self._conn = None async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", - (key, value, expiration), - ) - await db.commit() + assert self._conn is not None, "Connection not initialized. Call initialize() first." + await self._conn.execute( + f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", + (key, value, expiration), + ) + await self._conn.commit() async def get(self, key: str) -> str | None: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor: - row = await cursor.fetchone() - if row is None: - return None - value, expiration = row - if not isinstance(value, str): - logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") - return None - return value + assert self._conn is not None, "Connection not initialized. Call initialize() first." + async with self._conn.execute( + f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) + ) as cursor: + row = await cursor.fetchone() + if row is None: + return None + value, expiration = row + if not isinstance(value, str): + logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") + return None + return value async def delete(self, key: str) -> None: - async with aiosqlite.connect(self.db_path) as db: - await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) - await db.commit() + assert self._conn is not None, "Connection not initialized. Call initialize() first." + await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await self._conn.commit() async def values_in_range(self, start_key: str, end_key: str) -> list[str]: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( - f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", - (start_key, end_key), - ) as cursor: - result = [] - async for row in cursor: - _, value, _ = row - result.append(value) - return result + assert self._conn is not None, "Connection not initialized. Call initialize() first." + async with self._conn.execute( + f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) as cursor: + result = [] + async for row in cursor: + _, value, _ = row + result.append(value) + return result async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: """Get all keys in the given range.""" - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute( - f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", - (start_key, end_key), - ) - rows = await cursor.fetchall() - return [row[0] for row in rows] + assert self._conn is not None, "Connection not initialized. Call initialize() first." + cursor = await self._conn.execute( + f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) + rows = await cursor.fetchall() + return [row[0] for row in rows] diff --git a/tests/unit/utils/kvstore/test_sqlite_memory.py b/tests/unit/utils/kvstore/test_sqlite_memory.py new file mode 100644 index 000000000..5c9d4c4ff --- /dev/null +++ b/tests/unit/utils/kvstore/test_sqlite_memory.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.providers.utils.kvstore.sqlite.config import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl + + +async def test_memory_kvstore_basic_operations(): + """Test basic CRUD operations with :memory: database.""" + config = SqliteKVStoreConfig(db_path=":memory:") + store = SqliteKVStoreImpl(config) + await store.initialize() + + # Test set and get + await store.set("key1", "value1") + result = await store.get("key1") + assert result == "value1" + + # Test get non-existent key + result = await store.get("nonexistent") + assert result is None + + # Test update + await store.set("key1", "updated_value") + result = await store.get("key1") + assert result == "updated_value" + + # Test delete + await store.delete("key1") + result = await store.get("key1") + assert result is None + + await store.close() + + +async def test_memory_kvstore_range_operations(): + """Test range query operations with :memory: database.""" + config = SqliteKVStoreConfig(db_path=":memory:") + store = SqliteKVStoreImpl(config) + await store.initialize() + + # Set up test data + await store.set("key_a", "value_a") + await store.set("key_b", "value_b") + await store.set("key_c", "value_c") + await store.set("key_d", "value_d") + + # Test values_in_range + values = await store.values_in_range("key_b", "key_c") + assert len(values) == 2 + assert "value_b" in values + assert "value_c" in values + + # Test keys_in_range + keys = await store.keys_in_range("key_a", "key_c") + assert len(keys) == 3 + assert "key_a" in keys + assert "key_b" in keys + assert "key_c" in keys + + await store.close() + + +async def test_memory_kvstore_multiple_instances(): + """Test that multiple :memory: instances are independent.""" + config1 = SqliteKVStoreConfig(db_path=":memory:") + config2 = SqliteKVStoreConfig(db_path=":memory:") + + store1 = SqliteKVStoreImpl(config1) + store2 = SqliteKVStoreImpl(config2) + + await store1.initialize() + await store2.initialize() + + # Set data in store1 + await store1.set("shared_key", "value_from_store1") + + # Verify store2 doesn't see store1's data + result = await store2.get("shared_key") + assert result is None + + # Set different value in store2 + await store2.set("shared_key", "value_from_store2") + + # Verify both stores have independent data + assert await store1.get("shared_key") == "value_from_store1" + assert await store2.get("shared_key") == "value_from_store2" + + await store1.close() + await store2.close() + + +async def test_memory_kvstore_persistence_behavior(): + """Test that :memory: database doesn't persist across instances.""" + config = SqliteKVStoreConfig(db_path=":memory:") + + # First instance + store1 = SqliteKVStoreImpl(config) + await store1.initialize() + await store1.set("persist_test", "should_not_persist") + await store1.close() + + # Second instance with same config + store2 = SqliteKVStoreImpl(config) + await store2.initialize() + + # Data should not be present + result = await store2.get("persist_test") + assert result is None + + await store2.close()