change self.connection to always get_connection

This commit is contained in:
Kai Wu 2025-03-03 14:53:25 -08:00
parent 2868d8f793
commit f037510f3a

View file

@ -23,7 +23,6 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
logger = logging.getLogger(__name__)
def serialize_vector(vector: List[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation."""
return struct.pack(f"{len(vector)}f", *vector)
@ -148,6 +147,7 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
"""
A VectorIO implementation using SQLite + sqlite_vec.
@ -173,11 +173,11 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def initialize(self) -> None:
# Open a connection to the SQLite database (the file is specified in the config).
self.connection = self._get_connection()
self.connection.enable_load_extension(True)
sqlite_vec.load(self.connection)
self.connection.enable_load_extension(False)
cur = self.connection.cursor()
connection = self._get_connection()
connection.enable_load_extension(True)
sqlite_vec.load(connection)
connection.enable_load_extension(False)
cur = connection.cursor()
# Create a table to persist vector DB registrations.
cur.execute("""
CREATE TABLE IF NOT EXISTS vector_dbs (
@ -185,14 +185,14 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
metadata TEXT
);
""")
self.connection.commit()
connection.commit()
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
rows = cur.fetchall()
for row in rows:
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, connection, vector_db.identifier)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def shutdown(self) -> None:
@ -206,31 +206,29 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
del self._local.conn
async def register_vector_db(self, vector_db: VectorDB) -> None:
if self.connection is None:
raise RuntimeError("SQLite connection not initialized")
cur = self.connection.cursor()
connection = self._get_connection()
cur = connection.cursor()
cur.execute(
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
(vector_db.identifier, vector_db.model_dump_json()),
)
self.connection.commit()
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
connection.commit()
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, connection, vector_db.identifier)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> List[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
if self.connection is None:
raise RuntimeError("SQLite connection not initialized")
connection = self._get_connection()
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
cur = self.connection.cursor()
cur = connection.cursor()
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
self.connection.commit()
connection.commit()
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
if vector_db_id not in self.cache: