diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 055bf5232..faf7ff18c 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -241,7 +241,10 @@ API responses, specify the adapter here. def remote_provider_spec( - api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None + api: Api, + adapter: AdapterSpec, + api_dependencies: list[Api] | None = None, + optional_api_dependencies: list[Api] | None = None, ) -> RemoteProviderSpec: return RemoteProviderSpec( api=api, @@ -250,6 +253,7 @@ def remote_provider_spec( module=adapter.module, adapter=adapter, api_dependencies=api_dependencies or [], + optional_api_dependencies=optional_api_dependencies or [], ) diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index a57b4a4ee..edee4649d 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex): self.kvstore = kvstore self.bank_id = bank_id + # A list of chunk id's in the same order as they are in the index, + # must be updated when chunks are added or removed + self.chunk_id_lock = asyncio.Lock() + self.chunk_ids: list[Any] = [] + @classmethod async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): instance = cls(dimension, kvstore, bank_id) @@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex): buffer = io.BytesIO(base64.b64decode(data["faiss_index"])) try: self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False)) + self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()] except Exception as e: logger.debug(e, exc_info=True) raise ValueError( @@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex): for i, chunk in enumerate(chunks): self.chunk_by_index[indexlen + i] = chunk - 
self.index.add(np.array(embeddings).astype(np.float32)) + async with self.chunk_id_lock: + self.index.add(np.array(embeddings).astype(np.float32)) + self.chunk_ids.extend([chunk.chunk_id for chunk in chunks]) # Save updated index await self._save_index() + async def delete_chunk(self, chunk_id: str) -> None: + if chunk_id not in self.chunk_ids: + return + + async with self.chunk_id_lock: + index = self.chunk_ids.index(chunk_id) + self.index.remove_ids(np.array([index])) + + new_chunk_by_index = {} + for idx, chunk in self.chunk_by_index.items(): + # Shift all chunks after the removed chunk to the left; the removed entry itself (idx == index) is dropped + if idx > index: + new_chunk_by_index[idx - 1] = chunk + elif idx < index: + new_chunk_by_index[idx] = chunk + self.chunk_by_index = new_chunk_by_index + self.chunk_ids.pop(index) + + await self._save_index() + async def query_vector( self, embedding: NDArray, @@ -260,3 +288,9 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr raise ValueError(f"Vector DB {vector_db_id} not found") return await index.query_chunks(query, params) + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a faiss index""" + faiss_index = self.cache[store_id].index + for chunk_id in chunk_ids: + await faiss_index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index f2598cc7c..cfa4e2263 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -425,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def delete_chunk(self, chunk_id: str) -> None: + """Remove a chunk from the SQLite vector store.""" + + def _delete_chunk(): + connection = _create_sqlite_connection(self.db_path) + cur = connection.cursor() + try: + cur.execute("BEGIN TRANSACTION") + + #
Delete from metadata table + cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,)) + + # Delete from vector table + cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,)) + + # Delete from FTS table + cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,)) + + connection.commit() + except Exception as e: + connection.rollback() + logger.error(f"Error deleting chunk {chunk_id}: {e}") + raise + finally: + cur.close() + connection.close() + + await asyncio.to_thread(_delete_chunk) + class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): """ @@ -520,3 +549,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc if not index: raise ValueError(f"Vector DB {vector_db_id} not found") return await index.query_chunks(query, params) + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a sqlite_vec index.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + for chunk_id in chunk_ids: + # Use the index's delete_chunk method + await index.index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index e391341b4..063b382df 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), remote_provider_spec( Api.vector_io, diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index c16661b67..26aeaedfb 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -115,6 +115,9 @@ 
class ChromaIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Chroma") + async def delete_chunk(self, chunk_id: str) -> None: + raise NotImplementedError("delete_chunk is not supported in Chroma") + async def query_hybrid( self, embedding: NDArray, @@ -223,3 +226,6 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api) self.cache[vector_db_id] = index return index + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index dc4852821..f1652a80e 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -247,6 +247,16 @@ class MilvusIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Hybrid search is not supported in Milvus") + async def delete_chunk(self, chunk_id: str) -> None: + """Remove a chunk from the Milvus collection.""" + try: + await asyncio.to_thread( + self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"' + ) + except Exception as e: + logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}") + raise + class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( @@ -369,3 +379,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) return await index.query_chunks(query, params) + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a milvus vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise 
ValueError(f"Vector DB {store_id} not found") + + for chunk_id in chunk_ids: + # Use the index's delete_chunk method + await index.index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/remote/vector_io/pgvector/__init__.py b/llama_stack/providers/remote/vector_io/pgvector/__init__.py index 9f528db74..59eef4c81 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/__init__.py +++ b/llama_stack/providers/remote/vector_io/pgvector/__init__.py @@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]): from .pgvector import PGVectorVectorIOAdapter - impl = PGVectorVectorIOAdapter(config, deps[Api.inference]) + impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None)) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 3aeb3f30d..643c27328 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex): for i, chunk in enumerate(chunks): values.append( ( - f"{chunk.metadata['document_id']}:chunk-{i}", + f"{chunk.chunk_id}", Json(chunk.model_dump()), embeddings[i].tolist(), ) @@ -159,6 +159,11 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + async def delete_chunk(self, chunk_id: str) -> None: + """Remove a chunk from the PostgreSQL table.""" + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,)) + class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( @@ -265,3 +270,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, 
VectorIO, VectorDBsProtoco index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn) self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) return self.cache[vector_db_id] + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a PostgreSQL vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + for chunk_id in chunk_ids: + # Use the index's delete_chunk method + await index.index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 5bdea0ce8..3df3da27f 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -82,6 +82,9 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) + async def delete_chunk(self, chunk_id: str) -> None: + raise NotImplementedError("delete_chunk is not supported in qdrant") + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = ( await self.client.query_points( @@ -307,3 +310,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): file_id: str, ) -> VectorStoreFileObject: raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 35bb40454..543835e20 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -66,6 +66,9 @@ class 
WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) + async def delete_chunk(self, chunk_id: str) -> None: + raise NotImplementedError("delete_chunk is not supported in Weaviate") + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: collection = self.client.collections.get(self.collection_name) @@ -264,3 +267,6 @@ class WeaviateVectorIOAdapter( async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index f178e9299..ee69d7c52 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -152,6 +152,11 @@ class OpenAIVectorStoreMixin(ABC): """Load existing OpenAI vector stores into the in-memory cache.""" self.openai_vector_stores = await self._load_openai_vector_stores() + @abstractmethod + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a vector store.""" + pass + @abstractmethod async def register_vector_db(self, vector_db: VectorDB) -> None: """Register a vector database (provider-specific implementation).""" @@ -763,17 +768,15 @@ class OpenAIVectorStoreMixin(ABC): if vector_store_id not in self.openai_vector_stores: raise ValueError(f"Vector store {vector_store_id} not found") + dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) + chunks = [Chunk.model_validate(c) for c in dict_chunks] + await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c
in chunks if c.chunk_id]) + store_info = self.openai_vector_stores[vector_store_id].copy() file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id) await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id) - # TODO: We need to actually delete the embeddings from the underlying vector store... - # Also uncomment the corresponding integration test marked as xfail - # - # test_openai_vector_store_delete_file_removes_from_vector_store in - # tests/integration/vector_io/test_openai_vector_stores.py - # Update in-memory cache store_info["file_ids"].remove(file_id) store_info["file_counts"][file.status] -= 1 diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index f892d33c6..4a8749cba 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -231,6 +231,10 @@ class EmbeddingIndex(ABC): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): raise NotImplementedError() + @abstractmethod + async def delete_chunk(self, chunk_id: str): + raise NotImplementedError() + @abstractmethod async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError() diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 9771ab290..a34c5b410 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -723,8 +723,6 @@ def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client assert updated_vector_store.file_counts.in_progress == 0 -# TODO: Remove this xfail once we have a way to remove embeddings from vector store -@pytest.mark.xfail(reason="Vector Store Files delete doesn't remove embeddings from vector store", strict=True) def 
test_openai_vector_store_delete_file_removes_from_vector_store(compat_client_with_empty_stores, client_with_models): """Test OpenAI vector store delete file removes from vector store.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)