mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 14:38:49 +00:00
feat: implement chunk deletion for vector stores

Add support for deleting individual chunks from vector stores:

- Add abstract delete_chunk() method to EmbeddingIndex base class
- Implement chunk deletion for Faiss provider with index tracking
- Add chunk_ids list to maintain chunk order in Faiss index
- Integrate chunk deletion into OpenAI vector store file deletion flow
- Add placeholder implementations for SQLite and Milvus providers

Closes: #2477
Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent cd8715d327
commit 3d4b32db0a
5 changed files with 57 additions and 1 deletion
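For orientation, a minimal sketch of the new API surface (ids illustrative; the create() signature is taken from the Faiss hunk below):

    index = await FaissIndex.create(dimension=128)
    await index.delete_chunk("chunk-abc")  # returns silently if the id is not in the index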
@@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
         self.kvstore = kvstore
         self.bank_id = bank_id
+
+        # A list of chunk id's in the same order as they are in the index,
+        # must be updated when chunks are added or removed
+        self.chunk_id_lock = asyncio.Lock()
+        self.chunk_ids: list[Any] = []
 
     @classmethod
     async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         instance = cls(dimension, kvstore, bank_id)
@@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
         buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
         try:
             self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
+            self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
         except Exception as e:
             logger.debug(e, exc_info=True)
             raise ValueError(
@@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
 
-        self.index.add(np.array(embeddings).astype(np.float32))
+        async with self.chunk_id_lock:
+            self.index.add(np.array(embeddings).astype(np.float32))
+            self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
 
         # Save updated index
         await self._save_index()
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        if chunk_id not in self.chunk_ids:
+            return
+
+        async with self.chunk_id_lock:
+            index = self.chunk_ids.index(chunk_id)
+            self.index.remove_ids(np.array([index]))
+
+            new_chunk_by_index = {}
+            for idx, chunk in self.chunk_by_index.items():
+                # Shift all chunks after the removed chunk to the left
+                if idx > index:
+                    new_chunk_by_index[idx - 1] = chunk
+                else:
+                    new_chunk_by_index[idx] = chunk
+            self.chunk_by_index = new_chunk_by_index
+            self.chunk_ids.pop(index)
+
+        await self._save_index()
+
     async def query_vector(
         self,
         embedding: NDArray,
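A side note on the bookkeeping above: this standalone sketch (not part of the commit) shows why chunk_by_index must be re-keyed after remove_ids — removing a vector from a flat Faiss index compacts the storage, shifting every later sequential id down by one:

    import faiss
    import numpy as np

    index = faiss.IndexFlatL2(4)
    vectors = np.random.rand(3, 4).astype(np.float32)
    index.add(vectors)

    index.remove_ids(np.array([1], dtype=np.int64))  # drop the vector at position 1

    # The vector formerly at position 2 now sits at position 1, mirroring
    # the idx - 1 shift applied to chunk_by_index in delete_chunk() above.
    assert np.allclose(index.reconstruct(1), vectors[2])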
@@ -260,3 +288,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             raise ValueError(f"Vector DB {vector_db_id} not found")
 
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a faiss index"""
+        faiss_index = self.cache[store_id].index
+        await faiss_index.delete_chunk(chunk_id)
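A hypothetical call-through (store and chunk ids illustrative), assuming an initialized FaissVectorIOAdapter whose cache holds the store:

    await adapter._delete_openai_chunk_from_vector_store("vs_123", "chunk-abc")
    # An unknown store_id raises KeyError from the cache lookup; an unknown
    # chunk_id is a silent no-op inside FaissIndex.delete_chunk().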
@@ -520,3 +520,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         if not index:
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a sqlite_vec index."""
+        pass  # TODO
@@ -369,3 +369,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         )
 
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a milvus vector store."""
+        pass  # TODO
@@ -152,6 +152,11 @@ class OpenAIVectorStoreMixin(ABC):
         """Load existing OpenAI vector stores into the in-memory cache."""
         self.openai_vector_stores = await self._load_openai_vector_stores()
 
+    @abstractmethod
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a vector store."""
+        pass
+
     @abstractmethod
     async def register_vector_db(self, vector_db: VectorDB) -> None:
         """Register a vector database (provider-specific implementation)."""
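Any provider mixing in OpenAIVectorStoreMixin must now supply this hook. A hypothetical subclass sketch (the cache lookup is provider-specific; this mirrors the Faiss adapter above):

    class MyVectorIOAdapter(OpenAIVectorStoreMixin):
        async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
            index = self.cache[store_id].index  # provider-specific lookup
            await index.delete_chunk(chunk_id)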
@@ -763,6 +768,12 @@ class OpenAIVectorStoreMixin(ABC):
         if vector_store_id not in self.openai_vector_stores:
             raise ValueError(f"Vector store {vector_store_id} not found")
 
+        dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
+        chunks = [Chunk.model_validate(c) for c in dict_chunks]
+        for c in chunks:
+            if c.chunk_id:
+                await self._delete_openai_chunk_from_vector_store(vector_store_id, str(c.chunk_id))
+
         store_info = self.openai_vector_stores[vector_store_id].copy()
 
         file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
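Note the ordering this hunk establishes: chunks are removed from the underlying index before the store metadata is touched, so if the provider hook raises, the file's metadata is left intact. Condensed (ids illustrative):

    dict_chunks = await self._load_openai_vector_store_file_contents("vs_123", "file_456")
    for chunk in (Chunk.model_validate(c) for c in dict_chunks):
        if chunk.chunk_id:
            await self._delete_openai_chunk_from_vector_store("vs_123", str(chunk.chunk_id))
    # ...then the pre-existing bookkeeping runs: store_info copy, file retrieval, etc.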
@@ -231,6 +231,10 @@ class EmbeddingIndex(ABC):
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         raise NotImplementedError()
 
+    @abstractmethod
+    async def delete_chunk(self, chunk_id: str):
+        raise NotImplementedError()
+
     @abstractmethod
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError()