diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index a57b4a4ee..f7052c6e5 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
         self.kvstore = kvstore
         self.bank_id = bank_id
 
+        # A list of chunk IDs in the same order as they are in the index,
+        # must be updated when chunks are added or removed
+        self.chunk_id_lock = asyncio.Lock()
+        self.chunk_ids: list[Any] = []
+
     @classmethod
     async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         instance = cls(dimension, kvstore, bank_id)
@@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
         buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
         try:
             self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
+            self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
         except Exception as e:
             logger.debug(e, exc_info=True)
             raise ValueError(
@@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
 
-        self.index.add(np.array(embeddings).astype(np.float32))
+        async with self.chunk_id_lock:
+            self.index.add(np.array(embeddings).astype(np.float32))
+            self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
 
         # Save updated index
         await self._save_index()
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        if chunk_id not in self.chunk_ids:
+            return
+
+        async with self.chunk_id_lock:
+            index = self.chunk_ids.index(chunk_id)
+            self.index.remove_ids(np.array([index]))
+
+            new_chunk_by_index = {}
+            for idx, chunk in self.chunk_by_index.items():
+                # Shift all chunks after the removed chunk to the left
+                if idx > index:
+                    new_chunk_by_index[idx - 1] = chunk
+                elif idx < index:
+                    new_chunk_by_index[idx] = chunk
+            self.chunk_by_index = new_chunk_by_index
+            self.chunk_ids.pop(index)
+
+        await self._save_index()
+
     async def query_vector(
         self,
         embedding: NDArray,
@@ -260,3 +288,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
             raise ValueError(f"Vector DB {vector_db_id} not found")
 
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a faiss index."""
+        faiss_index = self.cache[store_id].index
+        await faiss_index.delete_chunk(chunk_id)
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index f2598cc7c..a8331e65a 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -520,3 +520,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         if not index:
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a sqlite_vec index."""
+        pass  # TODO
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index dc4852821..d41c7eb8d 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -369,3 +369,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
             )
 
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a milvus vector store."""
+        pass  # TODO
diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index f178e9299..294da69a1 100644
--- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -152,6 +152,11 @@ class OpenAIVectorStoreMixin(ABC):
         """Load existing OpenAI vector stores into the in-memory cache."""
         self.openai_vector_stores = await self._load_openai_vector_stores()
 
+    @abstractmethod
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a vector store."""
+        pass
+
     @abstractmethod
     async def register_vector_db(self, vector_db: VectorDB) -> None:
         """Register a vector database (provider-specific implementation)."""
@@ -763,6 +768,12 @@ class OpenAIVectorStoreMixin(ABC):
         if vector_store_id not in self.openai_vector_stores:
             raise ValueError(f"Vector store {vector_store_id} not found")
 
+        dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
+        chunks = [Chunk.model_validate(c) for c in dict_chunks]
+        for c in chunks:
+            if c.chunk_id:
+                await self._delete_openai_chunk_from_vector_store(vector_store_id, str(c.chunk_id))
+
         store_info = self.openai_vector_stores[vector_store_id].copy()
 
         file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index f892d33c6..4a8749cba 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -231,6 +231,10 @@ class EmbeddingIndex(ABC):
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         raise NotImplementedError()
 
+    @abstractmethod
+    async def delete_chunk(self, chunk_id: str):
+        raise NotImplementedError()
+
     @abstractmethod
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError()
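
For context on the trickiest piece of the change: faiss's flat index addresses vectors by insertion position, so removing one vector shifts every later position down by one, and both chunk_ids and chunk_by_index have to shift with it. Below is a minimal standalone sketch of that bookkeeping, not part of the patch: it assumes only faiss-cpu and numpy are installed, and the chunk IDs and payload strings are made up for illustration.

import faiss  # assumes the faiss-cpu package is available
import numpy as np

dim = 4
index = faiss.IndexFlatL2(dim)

chunk_ids: list[str] = []            # position -> chunk id, same order as the faiss index
chunk_by_index: dict[int, str] = {}  # position -> chunk payload (a plain string here)

def add_chunks(ids: list[str], vectors: np.ndarray) -> None:
    start = index.ntotal
    index.add(vectors.astype(np.float32))
    chunk_ids.extend(ids)
    for i, cid in enumerate(ids):
        chunk_by_index[start + i] = f"payload of {cid}"

def delete_chunk(chunk_id: str) -> None:
    if chunk_id not in chunk_ids:
        return
    pos = chunk_ids.index(chunk_id)
    # A flat index identifies vectors by position, so removing one shifts later positions left
    index.remove_ids(np.array([pos], dtype=np.int64))
    chunk_ids.pop(pos)
    # Rebuild the position -> chunk map: drop the deleted entry, shift later keys down by one
    shifted = {}
    for idx, payload in chunk_by_index.items():
        if idx < pos:
            shifted[idx] = payload
        elif idx > pos:
            shifted[idx - 1] = payload
    chunk_by_index.clear()
    chunk_by_index.update(shifted)

add_chunks(["chunk-a", "chunk-b", "chunk-c"], np.random.rand(3, dim))
delete_chunk("chunk-b")
assert index.ntotal == 2
assert chunk_ids == ["chunk-a", "chunk-c"]
assert chunk_by_index == {0: "payload of chunk-a", 1: "payload of chunk-c"}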