mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
change self.connection to always get_connection
This commit is contained in:
parent
2868d8f793
commit
f037510f3a
1 changed files with 15 additions and 17 deletions
|
@ -23,7 +23,6 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def serialize_vector(vector: List[float]) -> bytes:
|
def serialize_vector(vector: List[float]) -> bytes:
|
||||||
"""Serialize a list of floats into a compact binary representation."""
|
"""Serialize a list of floats into a compact binary representation."""
|
||||||
return struct.pack(f"{len(vector)}f", *vector)
|
return struct.pack(f"{len(vector)}f", *vector)
|
||||||
|
@ -148,6 +147,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
"""
|
"""
|
||||||
A VectorIO implementation using SQLite + sqlite_vec.
|
A VectorIO implementation using SQLite + sqlite_vec.
|
||||||
|
@ -173,11 +173,11 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
# Open a connection to the SQLite database (the file is specified in the config).
|
# Open a connection to the SQLite database (the file is specified in the config).
|
||||||
self.connection = self._get_connection()
|
connection = self._get_connection()
|
||||||
self.connection.enable_load_extension(True)
|
connection.enable_load_extension(True)
|
||||||
sqlite_vec.load(self.connection)
|
sqlite_vec.load(connection)
|
||||||
self.connection.enable_load_extension(False)
|
connection.enable_load_extension(False)
|
||||||
cur = self.connection.cursor()
|
cur = connection.cursor()
|
||||||
# Create a table to persist vector DB registrations.
|
# Create a table to persist vector DB registrations.
|
||||||
cur.execute("""
|
cur.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS vector_dbs (
|
CREATE TABLE IF NOT EXISTS vector_dbs (
|
||||||
|
@ -185,14 +185,14 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
metadata TEXT
|
metadata TEXT
|
||||||
);
|
);
|
||||||
""")
|
""")
|
||||||
self.connection.commit()
|
connection.commit()
|
||||||
# Load any existing vector DB registrations.
|
# Load any existing vector DB registrations.
|
||||||
cur.execute("SELECT metadata FROM vector_dbs")
|
cur.execute("SELECT metadata FROM vector_dbs")
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
for row in rows:
|
for row in rows:
|
||||||
vector_db_data = row[0]
|
vector_db_data = row[0]
|
||||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
|
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, connection, vector_db.identifier)
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
@ -206,31 +206,29 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
del self._local.conn
|
del self._local.conn
|
||||||
|
|
||||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||||
if self.connection is None:
|
connection = self._get_connection()
|
||||||
raise RuntimeError("SQLite connection not initialized")
|
cur = connection.cursor()
|
||||||
cur = self.connection.cursor()
|
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
|
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
|
||||||
(vector_db.identifier, vector_db.model_dump_json()),
|
(vector_db.identifier, vector_db.model_dump_json()),
|
||||||
)
|
)
|
||||||
self.connection.commit()
|
connection.commit()
|
||||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier)
|
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, connection, vector_db.identifier)
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
async def list_vector_dbs(self) -> List[VectorDB]:
|
async def list_vector_dbs(self) -> List[VectorDB]:
|
||||||
return [v.vector_db for v in self.cache.values()]
|
return [v.vector_db for v in self.cache.values()]
|
||||||
|
|
||||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||||
if self.connection is None:
|
connection = self._get_connection()
|
||||||
raise RuntimeError("SQLite connection not initialized")
|
|
||||||
if vector_db_id not in self.cache:
|
if vector_db_id not in self.cache:
|
||||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||||
return
|
return
|
||||||
await self.cache[vector_db_id].index.delete()
|
await self.cache[vector_db_id].index.delete()
|
||||||
del self.cache[vector_db_id]
|
del self.cache[vector_db_id]
|
||||||
cur = self.connection.cursor()
|
cur = connection.cursor()
|
||||||
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
|
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
|
||||||
self.connection.commit()
|
connection.commit()
|
||||||
|
|
||||||
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
|
async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None:
|
||||||
if vector_db_id not in self.cache:
|
if vector_db_id not in self.cache:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue