mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
fix chunk deletion
This commit is contained in:
parent
c0be74b93e
commit
88cfab2768
10 changed files with 102 additions and 49 deletions
|
@ -33,6 +33,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -128,11 +129,12 @@ class FaissIndex(EmbeddingIndex):
|
|||
# Save updated index
|
||||
await self._save_index()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
if chunk_id not in self.chunk_ids:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
if not set(chunk_ids).issubset(self.chunk_ids):
|
||||
return
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
def remove_chunk(chunk_id: str):
|
||||
index = self.chunk_ids.index(chunk_id)
|
||||
self.index.remove_ids(np.array([index]))
|
||||
|
||||
|
@ -146,6 +148,10 @@ class FaissIndex(EmbeddingIndex):
|
|||
self.chunk_by_index = new_chunk_by_index
|
||||
self.chunk_ids.pop(index)
|
||||
|
||||
async with self.chunk_id_lock:
|
||||
for chunk_id in chunk_ids:
|
||||
remove_chunk(chunk_id)
|
||||
|
||||
await self._save_index()
|
||||
|
||||
async def query_vector(
|
||||
|
@ -297,8 +303,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a faiss index"""
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a faiss index"""
|
||||
faiss_index = self.cache[store_id].index
|
||||
for chunk_id in chunk_ids:
|
||||
await faiss_index.delete_chunk(chunk_id)
|
||||
await faiss_index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -31,6 +31,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIV
|
|||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -426,34 +427,35 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Remove a chunk from the SQLite vector store."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
|
||||
def _delete_chunk():
|
||||
def _delete_chunks():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
cur = connection.cursor()
|
||||
try:
|
||||
cur.execute("BEGIN TRANSACTION")
|
||||
|
||||
# Delete from metadata table
|
||||
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
|
||||
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ANY(?)", (chunk_ids,))
|
||||
|
||||
# Delete from vector table
|
||||
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
|
||||
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ANY(?)", (chunk_ids,))
|
||||
|
||||
# Delete from FTS table
|
||||
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
|
||||
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ANY(?)", (chunk_ids,))
|
||||
|
||||
connection.commit()
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
logger.error(f"Error deleting chunk {chunk_id}: {e}")
|
||||
logger.error(f"Error deleting chunks: {e}")
|
||||
raise
|
||||
finally:
|
||||
cur.close()
|
||||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_delete_chunk)
|
||||
await asyncio.to_thread(_delete_chunks)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
|
@ -551,12 +553,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a sqlite_vec index."""
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a sqlite_vec index."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -342,6 +342,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -350,6 +351,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
module="llama_stack.providers.inline.vector_io.chroma",
|
||||
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
|
@ -731,6 +733,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
|
|
@ -26,6 +26,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -115,8 +116,10 @@ class ChromaIndex(EmbeddingIndex):
|
|||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
raise NotImplementedError("delete_chunk is not supported in Chroma")
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a single chunk from the Chroma collection by its ID."""
|
||||
ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion]
|
||||
await maybe_await(self.collection.delete(ids=ids))
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
@ -144,6 +147,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.files_api = files_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
@ -227,5 +231,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a Chroma vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -28,6 +28,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
|
|||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -287,14 +288,15 @@ class MilvusIndex(EmbeddingIndex):
|
|||
|
||||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Remove a chunk from the Milvus collection."""
|
||||
chunk_ids_str = ",".join(f"'{c.chunk_id}'" for c in chunks_for_deletion)
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
|
||||
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id IN [{chunk_ids_str}]"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
|
||||
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
|
@ -420,12 +422,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a chunk from a milvus vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -163,10 +164,11 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Remove a chunk from the PostgreSQL table."""
|
||||
chunk_ids = [c.chunk_id for c in chunks_for_deletion]
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
|
||||
|
||||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
|
@ -275,12 +277,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete a chunk from a PostgreSQL vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise VectorStoreNotFoundError(store_id)
|
||||
|
||||
for chunk_id in chunk_ids:
|
||||
# Use the index's delete_chunk method
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig a
|
|||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -88,15 +89,16 @@ class QdrantIndex(EmbeddingIndex):
|
|||
|
||||
await self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Remove a chunk from the Qdrant collection."""
|
||||
chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion]
|
||||
try:
|
||||
await self.client.delete(
|
||||
collection_name=self.collection_name,
|
||||
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
|
||||
points_selector=models.PointIdsList(points=chunk_ids),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
|
||||
log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
|
@ -266,10 +268,10 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
async with self._qdrant_lock:
|
||||
await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy)
|
||||
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a Qdrant vector store."""
|
||||
index = await self._get_and_cache_vector_db_index(store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {store_id} not found")
|
||||
for chunk_id in chunk_ids:
|
||||
await index.index.delete_chunk(chunk_id)
|
||||
|
||||
await index.index.delete_chunks(chunks_for_deletion)
|
||||
|
|
|
@ -38,7 +38,11 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ChunkForDeletion,
|
||||
content_from_data_and_mime_type,
|
||||
make_overlapped_chunks,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -154,8 +158,8 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
@abstractmethod
|
||||
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
|
||||
"""Delete a chunk from a vector store."""
|
||||
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
|
||||
"""Delete chunks from a vector store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@ -767,7 +771,21 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
|
||||
chunks = [Chunk.model_validate(c) for c in dict_chunks]
|
||||
await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id])
|
||||
|
||||
# Create ChunkForDeletion objects with both chunk_id and document_id
|
||||
chunks_for_deletion = []
|
||||
for c in chunks:
|
||||
if c.chunk_id:
|
||||
document_id = c.metadata.get("document_id") or (
|
||||
c.chunk_metadata.document_id if c.chunk_metadata else None
|
||||
)
|
||||
if document_id:
|
||||
chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id))
|
||||
else:
|
||||
logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion")
|
||||
|
||||
if chunks_for_deletion:
|
||||
await self.delete_chunks(vector_store_id, chunks_for_deletion)
|
||||
|
||||
store_info = self.openai_vector_stores[vector_store_id].copy()
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from urllib.parse import unquote
|
|||
import httpx
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
|
@ -34,6 +35,18 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChunkForDeletion(BaseModel):
|
||||
"""Information needed to delete a chunk from a vector store.
|
||||
|
||||
:param chunk_id: The ID of the chunk to delete
|
||||
:param document_id: The ID of the document this chunk belongs to
|
||||
"""
|
||||
|
||||
chunk_id: str
|
||||
document_id: str
|
||||
|
||||
|
||||
# Constants for reranker types
|
||||
RERANKER_TYPE_RRF = "rrf"
|
||||
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||
|
@ -232,7 +245,7 @@ class EmbeddingIndex(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def delete_chunk(self, chunk_id: str):
|
||||
async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -588,16 +588,19 @@ def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_
|
|||
file_buffer.name = f"openai_test_{i}.txt"
|
||||
file = compat_client.files.create(file=file_buffer, purpose="assistants")
|
||||
|
||||
compat_client.vector_stores.files.create(
|
||||
response = compat_client.vector_stores.files.create(
|
||||
vector_store_id=vector_store.id,
|
||||
file_id=file.id,
|
||||
)
|
||||
assert response.status == "completed", (
|
||||
f"Failed to attach file {file.id} to vector store {vector_store.id}: {response=}"
|
||||
)
|
||||
file_ids.append(file.id)
|
||||
|
||||
files_list = compat_client.vector_stores.files.list(vector_store_id=vector_store.id)
|
||||
assert files_list
|
||||
assert files_list.object == "list"
|
||||
assert files_list.data
|
||||
assert files_list.data is not None
|
||||
assert not files_list.has_more
|
||||
assert len(files_list.data) == 3
|
||||
assert set(file_ids) == {file.id for file in files_list.data}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue