mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
change self.connection to always get_connection
This commit is contained in:
parent
2868d8f793
commit
f037510f3a
1 changed files with 15 additions and 17 deletions
|
@ -23,7 +23,6 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def serialize_vector(vector: List[float]) -> bytes:
|
||||
"""Serialize a list of floats into a compact binary representation."""
|
||||
return struct.pack(f"{len(vector)}f", *vector)
|
||||
|
@ -148,6 +147,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
"""
|
||||
A VectorIO implementation using SQLite + sqlite_vec.
|
||||
|
@ -173,11 +173,11 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
# Open a connection to the SQLite database (the file is specified in the config).
|
||||
self.connection = self._get_connection()
|
||||
self.connection.enable_load_extension(True)
|
||||
sqlite_vec.load(self.connection)
|
||||
self.connection.enable_load_extension(False)
|
||||
cur = self.connection.cursor()
|
||||
connection = self._get_connection()
|
||||
connection.enable_load_extension(True)
|
||||
sqlite_vec.load(connection)
|
||||
connection.enable_load_extension(False)
|
||||
cur = connection.cursor()
|
||||
# Create a table to persist vector DB registrations.
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS vector_dbs (
|
||||
|
@ -185,14 +185,14 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
metadata TEXT
|
||||
);
|
||||
""")
|
||||
self.connection.commit()
|
||||
connection.commit()
|
||||
# Load any existing vector DB registrations.
|
||||
cur.execute("SELECT metadata FROM vector_dbs")
|
||||
rows = cur.fetchall()
|
||||
for row in rows:
|
||||
vector_db_data = row[0]
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, connection, vector_db.identifier)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
@ -206,31 +206,29 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
del self._local.conn
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
if self.connection is None:
|
||||
raise RuntimeError("SQLite connection not initialized")
|
||||
cur = self.connection.cursor()
|
||||
connection = self._get_connection()
|
||||
cur = connection.cursor()
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
|
||||
(vector_db.identifier, vector_db.model_dump_json()),
|
||||
)
|
||||
self.connection.commit()
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
|
||||
connection.commit()
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, connection, vector_db.identifier)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def list_vector_dbs(self) -> List[VectorDB]:
|
||||
return [v.vector_db for v in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if self.connection is None:
|
||||
raise RuntimeError("SQLite connection not initialized")
|
||||
connection = self._get_connection()
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
cur = self.connection.cursor()
|
||||
cur = connection.cursor()
|
||||
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
|
||||
self.connection.commit()
|
||||
connection.commit()
|
||||
|
||||
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue