# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import os from datetime import datetime import aiosqlite from llama_stack.log import get_logger from ..api import KVStore from ..config import SqliteKVStoreConfig logger = get_logger(name=__name__, category="providers::utils") class SqliteKVStoreImpl(KVStore): def __init__(self, config: SqliteKVStoreConfig): self.db_path = config.db_path self.table_name = "kvstore" self._conn = None def __str__(self): return f"SqliteKVStoreImpl(db_path={self.db_path}, table_name={self.table_name})" def _is_memory_db(self) -> bool: """Check if this is an in-memory database.""" return self.db_path == ":memory:" or "mode=memory" in self.db_path async def initialize(self): # Skip directory creation for in-memory databases and file: URIs if not self._is_memory_db() and not self.db_path.startswith("file:"): db_dir = os.path.dirname(self.db_path) if db_dir: # Only create if there's a directory component os.makedirs(db_dir, exist_ok=True) # Create persistent connection for all databases self._conn = await aiosqlite.connect(self.db_path) await self._conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.table_name} ( key TEXT PRIMARY KEY, value TEXT, expiration TIMESTAMP ) """ ) await self._conn.commit() async def close(self): """Close the persistent connection.""" if self._conn: await self._conn.close() self._conn = None async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: await self._conn.execute( f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", (key, value, expiration), ) await self._conn.commit() async def get(self, key: str) -> str | None: async with self._conn.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor: row = await cursor.fetchone() if row is None: return None value, expiration = row if not isinstance(value, str): logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None") return None return value async def delete(self, key: str) -> None: await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await self._conn.commit() async def values_in_range(self, start_key: str, end_key: str) -> list[str]: async with self._conn.execute( f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", (start_key, end_key), ) as cursor: result = [] async for row in cursor: _, value, _ = row result.append(value) return result async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: """Get all keys in the given range.""" cursor = await self._conn.execute( f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", (start_key, end_key), ) rows = await cursor.fetchall() return [row[0] for row in rows]