From f037510f3a635c5395f8cd555d823bee9aff5e3c Mon Sep 17 00:00:00 2001 From: Kai Wu Date: Mon, 3 Mar 2025 14:53:25 -0800 Subject: [PATCH] change self.connection to always get_connection --- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 347d7ab2b..102cad8e2 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -23,7 +23,6 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect logger = logging.getLogger(__name__) - def serialize_vector(vector: List[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" return struct.pack(f"{len(vector)}f", *vector) @@ -148,6 +147,7 @@ class SQLiteVecIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): """ A VectorIO implementation using SQLite + sqlite_vec. @@ -173,11 +173,11 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def initialize(self) -> None: # Open a connection to the SQLite database (the file is specified in the config). - self.connection = self._get_connection() - self.connection.enable_load_extension(True) - sqlite_vec.load(self.connection) - self.connection.enable_load_extension(False) - cur = self.connection.cursor() + connection = self._get_connection() + connection.enable_load_extension(True) + sqlite_vec.load(connection) + connection.enable_load_extension(False) + cur = connection.cursor() # Create a table to persist vector DB registrations. cur.execute(""" CREATE TABLE IF NOT EXISTS vector_dbs ( @@ -185,14 +185,14 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): metadata TEXT ); """) - self.connection.commit() + connection.commit() # Load any existing vector DB registrations. cur.execute("SELECT metadata FROM vector_dbs") rows = cur.fetchall() for row in rows: vector_db_data = row[0] vector_db = VectorDB.model_validate_json(vector_db_data) - index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier) + index = await SQLiteVecIndex.create(vector_db.embedding_dimension, connection, vector_db.identifier) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) async def shutdown(self) -> None: @@ -206,31 +206,29 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): del self._local.conn async def register_vector_db(self, vector_db: VectorDB) -> None: - if self.connection is None: - raise RuntimeError("SQLite connection not initialized") - cur = self.connection.cursor() + connection = self._get_connection() + cur = connection.cursor() cur.execute( "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)", (vector_db.identifier, vector_db.model_dump_json()), ) - self.connection.commit() - index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier) + connection.commit() + index = await SQLiteVecIndex.create(vector_db.embedding_dimension, connection, vector_db.identifier) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) async def list_vector_dbs(self) -> List[VectorDB]: return [v.vector_db for v in self.cache.values()] async def unregister_vector_db(self, vector_db_id: str) -> None: - if self.connection is None: - raise RuntimeError("SQLite connection not initialized") + connection = self._get_connection() if vector_db_id not in self.cache: logger.warning(f"Vector DB {vector_db_id} not found") return await self.cache[vector_db_id].index.delete() del self.cache[vector_db_id] - cur = self.connection.cursor() + cur = connection.cursor() cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,)) - self.connection.commit() + connection.commit() async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None: if vector_db_id not in self.cache: