This commit is contained in:
Derek Higgins 2025-07-24 16:09:59 -07:00 committed by GitHub
commit 754fb32c59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 146 additions and 12 deletions

View file

@ -237,7 +237,10 @@ API responses, specify the adapter here.
def remote_provider_spec(
api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
@ -245,6 +248,7 @@ def remote_provider_spec(
config_class=adapter.config_class,
adapter=adapter,
api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
)

View file

@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
self.kvstore = kvstore
self.bank_id = bank_id
# A list of chunk id's in the same order as they are in the index,
# must be updated when chunks are added or removed
self.chunk_id_lock = asyncio.Lock()
self.chunk_ids: list[Any] = []
@classmethod
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
instance = cls(dimension, kvstore, bank_id)
@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
try:
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
except Exception as e:
logger.debug(e, exc_info=True)
raise ValueError(
@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
self.index.add(np.array(embeddings).astype(np.float32))
async with self.chunk_id_lock:
self.index.add(np.array(embeddings).astype(np.float32))
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
# Save updated index
await self._save_index()
async def delete_chunk(self, chunk_id: str) -> None:
if chunk_id not in self.chunk_ids:
return
async with self.chunk_id_lock:
index = self.chunk_ids.index(chunk_id)
self.index.remove_ids(np.array([index]))
new_chunk_by_index = {}
for idx, chunk in self.chunk_by_index.items():
# Shift all chunks after the removed chunk to the left
if idx > index:
new_chunk_by_index[idx - 1] = chunk
else:
new_chunk_by_index[idx] = chunk
self.chunk_by_index = new_chunk_by_index
self.chunk_ids.pop(index)
await self._save_index()
async def query_vector(
self,
embedding: NDArray,
@ -260,3 +288,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
"""Delete a chunk from a faiss index"""
faiss_index = self.cache[store_id].index
await faiss_index.delete_chunk(chunk_id)

View file

@ -425,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the SQLite vector store."""
def _delete_chunk():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute("BEGIN TRANSACTION")
# Delete from metadata table
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
# Delete from vector table
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
# Delete from FTS table
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
connection.commit()
except Exception as e:
connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}")
raise
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_chunk)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
@ -520,3 +549,12 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
"""Delete a chunk from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
remote_provider_spec(
Api.vector_io,

View file

@ -112,6 +112,9 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in Chroma")
async def query_hybrid(
self,
embedding: NDArray,
@ -208,3 +211,6 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index
return index
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@ -247,6 +247,16 @@ class MilvusIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus")
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the Milvus collection."""
try:
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
)
except Exception as e:
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
raise
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
@ -369,3 +379,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
)
return await index.query_chunks(query, params)
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
"""Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter
impl = PGVectorVectorIOAdapter(config, deps[Api.inference])
impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
await impl.initialize()
return impl

View file

@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks):
values.append(
(
f"{chunk.metadata['document_id']}:chunk-{i}",
f"{chunk.chunk_id}",
Json(chunk.model_dump()),
embeddings[i].tolist(),
)
@ -159,6 +159,11 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the PostgreSQL table."""
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(
@ -265,3 +270,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
"""Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -82,6 +82,9 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in qdrant")
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = (
await self.client.query_points(
@ -307,3 +310,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
file_id: str,
) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -66,6 +66,9 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in Chroma")
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name)
@ -264,3 +267,6 @@ class WeaviateVectorIOAdapter(
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")

View file

@ -152,6 +152,11 @@ class OpenAIVectorStoreMixin(ABC):
"""Load existing OpenAI vector stores into the in-memory cache."""
self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
"""Delete a chunk from a vector store."""
pass
@abstractmethod
async def register_vector_db(self, vector_db: VectorDB) -> None:
"""Register a vector database (provider-specific implementation)."""
@ -763,17 +768,17 @@ class OpenAIVectorStoreMixin(ABC):
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
for c in chunks:
if c.chunk_id:
await self._delete_openai_chunk_from_vector_store(vector_store_id, str(c.chunk_id))
store_info = self.openai_vector_stores[vector_store_id].copy()
file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
# TODO: We need to actually delete the embeddings from the underlying vector store...
# Also uncomment the corresponding integration test marked as xfail
#
# test_openai_vector_store_delete_file_removes_from_vector_store in
# tests/integration/vector_io/test_openai_vector_stores.py
# Update in-memory cache
store_info["file_ids"].remove(file_id)
store_info["file_counts"][file.status] -= 1

View file

@ -231,6 +231,10 @@ class EmbeddingIndex(ABC):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
raise NotImplementedError()
@abstractmethod
async def delete_chunk(self, chunk_id: str):
raise NotImplementedError()
@abstractmethod
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()

View file

@ -735,8 +735,6 @@ def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client
assert updated_vector_store.file_counts.in_progress == 0
# TODO: Remove this xfail once we have a way to remove embeddings from vector store
@pytest.mark.xfail(reason="Vector Store Files delete doesn't remove embeddings from vector store", strict=True)
def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client_with_empty_stores, client_with_models):
"""Test OpenAI vector store delete file removes from vector store."""
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)