mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 14:38:49 +00:00
Merge 05d3fffbdf
into 21bae296f2
This commit is contained in:
commit
754fb32c59
13 changed files with 146 additions and 12 deletions
|
@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.kvstore = kvstore
|
||||
self.bank_id = bank_id
|
||||
|
||||
# A list of chunk id's in the same order as they are in the index,
|
||||
# must be updated when chunks are added or removed
|
||||
self.chunk_id_lock = asyncio.Lock()
|
||||
self.chunk_ids: list[Any] = []
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
instance = cls(dimension, kvstore, bank_id)
|
||||
|
@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
|
||||
try:
|
||||
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
|
||||
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
|
||||
except Exception as e:
|
||||
logger.debug(e, exc_info=True)
|
||||
raise ValueError(
|
||||
|
@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
|
|||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
async with self.chunk_id_lock:
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
|
||||
|
||||
# Save updated index
|
||||
await self._save_index()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
if chunk_id not in self.chunk_ids:
|
||||
return
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
index = self.chunk_ids.index(chunk_id)
|
||||
self.index.remove_ids(np.array([index]))
|
||||
|
||||
new_chunk_by_index = {}
|
||||
for idx, chunk in self.chunk_by_index.items():
|
||||
# Shift all chunks after the removed chunk to the left
|
||||
if idx > index:
|
||||
new_chunk_by_index[idx - 1] = chunk
|
||||
else:
|
||||
new_chunk_by_index[idx] = chunk
|
||||
self.chunk_by_index = new_chunk_by_index
|
||||
self.chunk_ids.pop(index)
|
||||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
|
@ -260,3 +288,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
|
||||
"""Delete a chunk from a faiss index"""
|
||||
faiss_index = self.cache[store_id].index
|
||||
await faiss_index.delete_chunk(chunk_id)
|
||||
|
|
|
@ -425,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
"""Remove a chunk from the SQLite vector store."""
|
||||
|
||||
def _delete_chunk():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
|
||||
# Delete from metadata table
|
||||
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
# Delete from vector table
|
||||
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
# Delete from FTS table
|
||||
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
|
||||
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete_chunk)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
"""
|
||||
|
@ -520,3 +549,12 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
|
||||
"""Delete a chunk from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue