From 1db4800c9cdfffbc17e63c34ed6ca426f5acb928 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 8 Sep 2025 17:19:42 +0200 Subject: [PATCH 1/9] feat(client): migrate MilvusClient to AsyncMilvusClient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The commit makes the follwing changes. - Import statements updated: MilvusClient → AsyncMilvusClient - Removed asyncio.to_thread() wrappers: All Milvus operations now use native async/await - Test compatibility: Mock objects and fixtures updated to work with AsyncMilvusClient Signed-off-by: Mustafa Elbehery --- .../remote/vector_io/milvus/milvus.py | 45 ++- tests/unit/providers/vector_io/conftest.py | 50 ++- .../providers/vector_io/remote/test_milvus.py | 339 ++++++++++++++++++ .../test_vector_io_openai_vector_stores.py | 6 +- 4 files changed, 413 insertions(+), 27 deletions(-) create mode 100644 tests/unit/providers/vector_io/remote/test_milvus.py diff --git a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py index 73339b5be..d0bee8965 100644 --- a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio import os from typing import Any from numpy.typing import NDArray -from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker +from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files import Files @@ -44,7 +43,11 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class MilvusIndex(EmbeddingIndex): def __init__( - self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None + self, + client: AsyncMilvusClient, + collection_name: str, + consistency_level="Strong", + kvstore: KVStore | None = None, ): self.client = client self.collection_name = sanitize_collection_name(collection_name) @@ -57,15 +60,15 @@ class MilvusIndex(EmbeddingIndex): pass async def delete(self): - if await asyncio.to_thread(self.client.has_collection, self.collection_name): - await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) + if await self.client.has_collection(self.collection_name): + await self.client.drop_collection(collection_name=self.collection_name) async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - if not await asyncio.to_thread(self.client.has_collection, self.collection_name): + if not await self.client.has_collection(self.collection_name): logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") # Create schema for vector search schema = self.client.create_schema() @@ -96,8 +99,7 @@ class MilvusIndex(EmbeddingIndex): ) schema.add_function(bm25_function) - await asyncio.to_thread( - self.client.create_collection, + await self.client.create_collection( self.collection_name, schema=schema, index_params=index_params, @@ -116,14 +118,16 @@ class MilvusIndex(EmbeddingIndex): } ) try: - await asyncio.to_thread(self.client.insert, self.collection_name, data=data) + await self.client.insert( + self.collection_name, + data=data, + ) except Exception as e: logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") raise e async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - search_res = await asyncio.to_thread( - self.client.search, + search_res = await self.client.search( collection_name=self.collection_name, data=[embedding], anns_field="vector", @@ -141,8 +145,7 @@ class MilvusIndex(EmbeddingIndex): """ try: # Use Milvus's built-in BM25 search - search_res = await asyncio.to_thread( - self.client.search, + search_res = await self.client.search( collection_name=self.collection_name, data=[query_string], # Raw text query anns_field="sparse", # Use sparse field for BM25 @@ -178,8 +181,7 @@ class MilvusIndex(EmbeddingIndex): Fallback to simple text search when BM25 search is not available. """ # Simple text search using content field - search_res = await asyncio.to_thread( - self.client.query, + search_res = await self.client.query( collection_name=self.collection_name, filter='content like "%{content}%"', filter_params={"content": query_string}, @@ -226,8 +228,7 @@ class MilvusIndex(EmbeddingIndex): impact_factor = (reranker_params or {}).get("impact_factor", 60.0) rerank = RRFRanker(impact_factor) - search_res = await asyncio.to_thread( - self.client.hybrid_search, + search_res = await self.client.hybrid_search( collection_name=self.collection_name, reqs=search_requests, ranker=rerank, @@ -253,9 +254,7 @@ class MilvusIndex(EmbeddingIndex): try: # Use IN clause with square brackets and single quotes for VARCHAR field chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids) - await asyncio.to_thread( - self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]" - ) + await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]") except Exception as e: logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") raise @@ -297,17 +296,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc self.cache[vector_store.identifier] = index if isinstance(self.config, RemoteMilvusVectorIOConfig): logger.info(f"Connecting to Milvus server at {self.config.uri}") - self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) + self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) else: logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") uri = os.path.expanduser(self.config.db_path) - self.client = MilvusClient(uri=uri) + self.client = AsyncMilvusClient(uri=uri) # Load existing OpenAI vector stores into the in-memory cache await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - self.client.close() + await self.client.close() # Clean up mixin resources (file batch tasks) await super().shutdown() diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 5e56ea417..0960a9c52 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -9,20 +9,26 @@ from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest +from chromadb import PersistentClient +from pymilvus import AsyncMilvusClient, connections +from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_stores import VectorStore from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter +from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter +from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter from llama_stack.providers.utils.kvstore import register_kvstore_backends EMBEDDING_DIMENSION = 768 COLLECTION_PREFIX = "test_collection" +MILVUS_ALIAS = "test_milvus" @pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"]) @@ -141,7 +147,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory): await index.initialize() index.db_path = db_path yield index - index.delete() + await index.delete() @pytest.fixture @@ -170,6 +176,48 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf await adapter.shutdown() +@pytest.fixture(scope="session") +def milvus_vec_db_path(tmp_path_factory): + db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db") + return db_path + + +@pytest.fixture +async def milvus_vec_index(milvus_vec_db_path, embedding_dimension): + client = AsyncMilvusClient(milvus_vec_db_path) + name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" + connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path) + index = MilvusIndex(client, name, consistency_level="Strong") + index.db_path = milvus_vec_db_path + yield index + # Proper cleanup: close the async client + await client.close() + + +@pytest.fixture +async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): + config = MilvusVectorIOConfig( + db_path=milvus_vec_db_path, + kvstore=SqliteKVStoreConfig(), + ) + adapter = MilvusVectorIOAdapter( + config=config, + inference_api=mock_inference_api, + files_api=None, + ) + await adapter.initialize() + await adapter.register_vector_db( + VectorDB( + identifier=adapter.metadata_collection_name, + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=128, + ) + ) + yield adapter + await adapter.shutdown() + + @pytest.fixture def faiss_vec_db_path(tmp_path_factory): db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db") diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py new file mode 100644 index 000000000..6e8817366 --- /dev/null +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest + +from llama_stack.apis.vector_io import QueryChunksResponse + +# Mock the entire pymilvus module +pymilvus_mock = MagicMock() +pymilvus_mock.DataType = MagicMock() +pymilvus_mock.AsyncMilvusClient = MagicMock +pymilvus_mock.RRFRanker = MagicMock +pymilvus_mock.WeightedRanker = MagicMock +pymilvus_mock.AnnSearchRequest = MagicMock + +# Apply the mock before importing MilvusIndex +with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): + from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex + +# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain +# tests which are specific to this class. More general (API-level) tests should be placed in +# tests/integration/vector_io/ +# +# How to run this test: +# +# pytest tests/unit/providers/vector_io/test_milvus.py \ +# -v -s --tb=short --disable-warnings --asyncio-mode=auto + +MILVUS_PROVIDER = "milvus" + + +@pytest.fixture +async def mock_milvus_client() -> MagicMock: + """Create a mock Milvus client with common method behaviors.""" + client = MagicMock() + + # Mock async collection operations + client.has_collection = AsyncMock(return_value=False) # Initially no collection + client.create_collection = AsyncMock(return_value=None) + client.drop_collection = AsyncMock(return_value=None) + + # Mock async insert operation + client.insert = AsyncMock(return_value={"insert_count": 10}) + + # Mock async search operation - return mock results (data should be dict, not JSON string) + client.search = AsyncMock( + return_value=[ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + ) + + # Mock async query operation for keyword search (data should be dict, not JSON string) + client.query = AsyncMock( + return_value=[ + { + "chunk_id": "chunk1", + "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, + "score": 0.9, + }, + { + "chunk_id": "chunk2", + "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, + "score": 0.8, + }, + { + "chunk_id": "chunk3", + "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, + "score": 0.7, + }, + ] + ) + + # Mock async hybrid_search operation + client.hybrid_search = AsyncMock(return_value=[]) + + # Mock async delete operation + client.delete = AsyncMock(return_value=None) + + # Mock async close operation + client.close = AsyncMock(return_value=None) + + return client + + +@pytest.fixture +async def milvus_index(mock_milvus_client): + """Create a MilvusIndex with mocked client.""" + index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection") + yield index + # No real cleanup needed since we're using mocks + + +async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + # Setup: collection doesn't exist initially, then exists after creation + mock_milvus_client.has_collection.side_effect = [False, True] + + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Verify collection was created and data was inserted + mock_milvus_client.create_collection.assert_called_once() + mock_milvus_client.insert.assert_called_once() + + # Verify the insert call had the right number of chunks + insert_call = mock_milvus_client.insert.call_args + assert len(insert_call[1]["data"]) == len(sample_chunks) + + +async def test_query_chunks_vector( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + # Setup: Add chunks first + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Test vector search + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + mock_milvus_client.search.assert_called_once() + + +async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Test keyword search + query_string = "Sentence 5" + response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + + +async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + """Test that when BM25 search fails, the system falls back to simple text search.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Force BM25 search to fail + mock_milvus_client.search.side_effect = Exception("BM25 search not available") + + # Mock simple text search results + mock_milvus_client.query.return_value = [ + { + "chunk_id": "chunk1", + "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}}, + }, + { + "chunk_id": "chunk2", + "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}}, + }, + ] + + # Test keyword search that should fall back to simple text search + query_string = "Python" + response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0) + + # Verify response structure + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) > 0, "Fallback search should return results" + + # Verify that simple text search was used (query method called instead of search) + mock_milvus_client.query.assert_called_once() + mock_milvus_client.search.assert_called_once() # Called once but failed + + # Verify the query uses parameterized filter with filter_params + query_call_args = mock_milvus_client.query.call_args + assert "filter" in query_call_args[1], "Query should include filter for text search" + assert "filter_params" in query_call_args[1], "Query should use parameterized filter" + assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term" + + # Verify all returned chunks have score 1.0 (simple binary scoring) + assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring" + + +async def test_delete_collection(milvus_index, mock_milvus_client): + # Test collection deletion + mock_milvus_client.has_collection.return_value = True + + await milvus_index.delete() + + mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name) + + +async def test_query_hybrid_search_rrf( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with RRF reranker.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + + # Test hybrid search with RRF reranker + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=2, + score_threshold=0.0, + reranker_type="rrf", + reranker_params={"impact_factor": 60.0}, + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + + # Verify hybrid search was called with correct parameters + mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + + # Check that the request contains both vector and BM25 search requests + reqs = call_args[1]["reqs"] + assert len(reqs) == 2 + assert reqs[0].anns_field == "vector" + assert reqs[1].anns_field == "sparse" + ranker = call_args[1]["ranker"] + assert ranker is not None + + +async def test_query_hybrid_search_weighted( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with weighted reranker.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + + # Test hybrid search with weighted reranker + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=2, + score_threshold=0.0, + reranker_type="weighted", + reranker_params={"alpha": 0.7}, + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + + # Verify hybrid search was called with correct parameters + mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + ranker = call_args[1]["ranker"] + assert ranker is not None + + +async def test_query_hybrid_search_default_rrf( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with default RRF reranker (no reranker_type specified).""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + ] + ] + + # Test hybrid search with default reranker (should be RRF) + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=1, + score_threshold=0.0, + reranker_type="unknown_type", # Should default to RRF + reranker_params=None, # Should use default impact_factor + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 1 + + # Verify hybrid search was called with RRF reranker + mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + ranker = call_args[1]["ranker"] + assert ranker is not None diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 121623e1b..2143f5ccc 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -48,12 +48,12 @@ async def test_initialize_index(vector_index): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): - vector_index.delete() - vector_index.initialize() + await vector_index.delete() + await vector_index.initialize() await vector_index.add_chunks(sample_chunks, sample_embeddings) resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) assert resp.chunks[0].content == sample_chunks[0].content - vector_index.delete() + await vector_index.delete() async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension): From 6b299f24af775ca3a45a1472577f04fc9cc2870f Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 8 Sep 2025 18:13:03 +0200 Subject: [PATCH 2/9] fix(test): chrome db test fails due to reusing deleted collection Signed-off-by: Mustafa Elbehery --- .../providers/remote/vector_io/chroma/chroma.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py index 97e2244b8..7a74d27df 100644 --- a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -48,10 +48,11 @@ class ChromaIndex(EmbeddingIndex): def __init__(self, client: ChromaClientType, collection, kvstore: KVStore | None = None): self.client = client self.collection = collection + self.collection_name = collection.name self.kvstore = kvstore async def initialize(self): - pass + self.collection = await maybe_await(self.client.get_or_create_collection(self.collection_name)) async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( @@ -92,7 +93,13 @@ class ChromaIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) async def delete(self): - await maybe_await(self.client.delete_collection(self.collection.name)) + try: + await maybe_await(self.client.delete_collection(self.collection.name)) + except Exception as e: + if "does not exists" in str(e): + log.warning(f"Collection {self.collection.name} already deleted") + else: + raise async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Chroma") From 6aabf805661be663b96bab61c1fbc7074aaeb3ac Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 8 Sep 2025 18:28:19 +0200 Subject: [PATCH 3/9] chore: remove irrelevant comments Signed-off-by: Mustafa Elbehery --- tests/unit/providers/vector_io/conftest.py | 2 +- tests/unit/providers/vector_io/remote/test_milvus.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 0960a9c52..7426219c1 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -190,7 +190,7 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension): index = MilvusIndex(client, name, consistency_level="Strong") index.db_path = milvus_vec_db_path yield index - # Proper cleanup: close the async client + await client.close() diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 6e8817366..04bac71c2 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -48,7 +48,7 @@ async def mock_milvus_client() -> MagicMock: # Mock async insert operation client.insert = AsyncMock(return_value={"insert_count": 10}) - # Mock async search operation - return mock results (data should be dict, not JSON string) + # Mock async search operation client.search = AsyncMock( return_value=[ [ @@ -87,13 +87,10 @@ async def mock_milvus_client() -> MagicMock: ] ) - # Mock async hybrid_search operation client.hybrid_search = AsyncMock(return_value=[]) - # Mock async delete operation client.delete = AsyncMock(return_value=None) - # Mock async close operation client.close = AsyncMock(return_value=None) return client From f94d6146316d01e805fb8e375c9638f28c76f134 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 8 Sep 2025 22:00:59 +0200 Subject: [PATCH 4/9] fix(integration): init AsyncMilvusClient before MilvusIndex Signed-off-by: Mustafa Elbehery --- .../remote/vector_io/milvus/milvus.py | 100 ++++++++++++++---- .../providers/vector_io/remote/test_milvus.py | 3 - 2 files changed, 81 insertions(+), 22 deletions(-) diff --git a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py index d0bee8965..fdee58206 100644 --- a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -48,11 +48,13 @@ class MilvusIndex(EmbeddingIndex): collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None, + parent_adapter=None, ): self.client = client self.collection_name = sanitize_collection_name(collection_name) self.consistency_level = consistency_level self.kvstore = kvstore + self._parent_adapter = parent_adapter async def initialize(self): # MilvusIndex does not require explicit initialization @@ -60,15 +62,36 @@ class MilvusIndex(EmbeddingIndex): pass async def delete(self): - if await self.client.has_collection(self.collection_name): - await self.client.drop_collection(collection_name=self.collection_name) + try: + if await self.client.has_collection(self.collection_name): + await self.client.drop_collection(collection_name=self.collection_name) + except Exception as e: + logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}") async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - if not await self.client.has_collection(self.collection_name): + try: + collection_exists = await self.client.has_collection(self.collection_name) + except Exception as e: + logger.error(f"Failed to check collection existence: {self.collection_name} ({e})") + # If it's an event loop issue, try to recreate the client + if "attached to a different loop" in str(e): + logger.warning("Recreating client due to event loop issue") + + if hasattr(self, "_parent_adapter"): + await self._parent_adapter._recreate_client() + collection_exists = await self.client.has_collection(self.collection_name) + else: + # Assume collection doesn't exist if we can't check + collection_exists = False + else: + # Assume collection doesn't exist if we can't check due to other issues + collection_exists = False + + if not collection_exists: logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") # Create schema for vector search schema = self.client.create_schema() @@ -99,12 +122,16 @@ class MilvusIndex(EmbeddingIndex): ) schema.add_function(bm25_function) - await self.client.create_collection( - self.collection_name, - schema=schema, - index_params=index_params, - consistency_level=self.consistency_level, - ) + try: + await self.client.create_collection( + self.collection_name, + schema=schema, + index_params=index_params, + consistency_level=self.consistency_level, + ) + except Exception as e: + logger.error(f"Failed to create collection {self.collection_name}: {e}") + raise e data = [] for chunk, embedding in zip(chunks, embeddings, strict=False): @@ -277,6 +304,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.persistence) + + if isinstance(self.config, RemoteMilvusVectorIOConfig): + logger.info(f"Connecting to Milvus server at {self.config.uri}") + self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) + else: + logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") + uri = os.path.expanduser(self.config.db_path) + self.client = AsyncMilvusClient(uri=uri) + start_key = VECTOR_DBS_PREFIX end_key = f"{VECTOR_DBS_PREFIX}\xff" stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key) @@ -290,26 +326,41 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc collection_name=vector_store.identifier, consistency_level=self.config.consistency_level, kvstore=self.kvstore, + parent_adapter=self, ), inference_api=self.inference_api, ) self.cache[vector_store.identifier] = index - if isinstance(self.config, RemoteMilvusVectorIOConfig): - logger.info(f"Connecting to Milvus server at {self.config.uri}") - self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) - else: - logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") - uri = os.path.expanduser(self.config.db_path) - self.client = AsyncMilvusClient(uri=uri) # Load existing OpenAI vector stores into the in-memory cache await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - await self.client.close() + if self.client: + await self.client.close() # Clean up mixin resources (file batch tasks) await super().shutdown() + async def _recreate_client(self) -> None: + """Recreate the AsyncMilvusClient when event loop issues occur""" + try: + if self.client: + await self.client.close() + except Exception as e: + logger.warning(f"Error closing old client: {e}") + + if isinstance(self.config, RemoteMilvusVectorIOConfig): + logger.info(f"Recreating connection to Milvus server at {self.config.uri}") + self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) + else: + logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}") + uri = os.path.expanduser(self.config.db_path) + self.client = AsyncMilvusClient(uri=uri) + + for index_wrapper in self.cache.values(): + if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"): + index_wrapper.index.client = self.client + async def register_vector_store(self, vector_store: VectorStore) -> None: if isinstance(self.config, RemoteMilvusVectorIOConfig): consistency_level = self.config.consistency_level @@ -317,7 +368,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc consistency_level = "Strong" index = VectorStoreWithIndex( vector_store=vector_store, - index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level), + index=MilvusIndex( + client=self.client, + collection_name=vector_store.identifier, + consistency_level=consistency_level, + kvstore=self.kvstore, + parent_adapter=self, + ), inference_api=self.inference_api, ) @@ -339,7 +396,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc vector_store = VectorStore.model_validate_json(vector_store_data) index = VectorStoreWithIndex( vector_store=vector_store, - index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore), + index=MilvusIndex( + client=self.client, + collection_name=vector_store.identifier, + kvstore=self.kvstore, + parent_adapter=self, + ), inference_api=self.inference_api, ) self.cache[vector_store_id] = index diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 04bac71c2..25374e617 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -40,15 +40,12 @@ async def mock_milvus_client() -> MagicMock: """Create a mock Milvus client with common method behaviors.""" client = MagicMock() - # Mock async collection operations client.has_collection = AsyncMock(return_value=False) # Initially no collection client.create_collection = AsyncMock(return_value=None) client.drop_collection = AsyncMock(return_value=None) - # Mock async insert operation client.insert = AsyncMock(return_value={"insert_count": 10}) - # Mock async search operation client.search = AsyncMock( return_value=[ [ From c3ce2df439a0c37ed57dca78b55c284121ca6b19 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Tue, 9 Sep 2025 21:49:35 +0200 Subject: [PATCH 5/9] Revert "fix(test): chrome db test fails due to reusing deleted collection" This reverts commit aac26e8c6c2965fa6a5d2918c2b92ea34b074c46. --- .../providers/remote/vector_io/chroma/chroma.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py index 7a74d27df..97e2244b8 100644 --- a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -48,11 +48,10 @@ class ChromaIndex(EmbeddingIndex): def __init__(self, client: ChromaClientType, collection, kvstore: KVStore | None = None): self.client = client self.collection = collection - self.collection_name = collection.name self.kvstore = kvstore async def initialize(self): - self.collection = await maybe_await(self.client.get_or_create_collection(self.collection_name)) + pass async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( @@ -93,13 +92,7 @@ class ChromaIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) async def delete(self): - try: - await maybe_await(self.client.delete_collection(self.collection.name)) - except Exception as e: - if "does not exists" in str(e): - log.warning(f"Collection {self.collection.name} already deleted") - else: - raise + await maybe_await(self.client.delete_collection(self.collection.name)) async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Chroma") From 2347db868d9486a031461801bba2c39541536de3 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Tue, 9 Sep 2025 22:20:02 +0200 Subject: [PATCH 6/9] refactor(client): replace all AsyncMilvusClient usage of has_collection() with list_collections() Signed-off-by: Mustafa Elbehery --- .../remote/vector_io/milvus/milvus.py | 9 ++++++--- .../providers/vector_io/remote/test_milvus.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py index fdee58206..b84f1df81 100644 --- a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -63,7 +63,8 @@ class MilvusIndex(EmbeddingIndex): async def delete(self): try: - if await self.client.has_collection(self.collection_name): + collections = await self.client.list_collections() + if self.collection_name in collections: await self.client.drop_collection(collection_name=self.collection_name) except Exception as e: logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}") @@ -74,7 +75,8 @@ class MilvusIndex(EmbeddingIndex): ) try: - collection_exists = await self.client.has_collection(self.collection_name) + collections = await self.client.list_collections() + collection_exists = self.collection_name in collections except Exception as e: logger.error(f"Failed to check collection existence: {self.collection_name} ({e})") # If it's an event loop issue, try to recreate the client @@ -83,7 +85,8 @@ class MilvusIndex(EmbeddingIndex): if hasattr(self, "_parent_adapter"): await self._parent_adapter._recreate_client() - collection_exists = await self.client.has_collection(self.collection_name) + collections = await self.client.list_collections() + collection_exists = self.collection_name in collections else: # Assume collection doesn't exist if we can't check collection_exists = False diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 25374e617..7020900d7 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -40,7 +40,7 @@ async def mock_milvus_client() -> MagicMock: """Create a mock Milvus client with common method behaviors.""" client = MagicMock() - client.has_collection = AsyncMock(return_value=False) # Initially no collection + client.list_collections = AsyncMock(return_value=[]) # Initially no collections client.create_collection = AsyncMock(return_value=None) client.drop_collection = AsyncMock(return_value=None) @@ -103,7 +103,7 @@ async def milvus_index(mock_milvus_client): async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): # Setup: collection doesn't exist initially, then exists after creation - mock_milvus_client.has_collection.side_effect = [False, True] + mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]] await milvus_index.add_chunks(sample_chunks, sample_embeddings) @@ -120,7 +120,7 @@ async def test_query_chunks_vector( milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client ): # Setup: Add chunks first - mock_milvus_client.has_collection.return_value = True + mock_milvus_client.list_collections.return_value = ["test_collection"] await milvus_index.add_chunks(sample_chunks, sample_embeddings) # Test vector search @@ -133,7 +133,7 @@ async def test_query_chunks_vector( async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): - mock_milvus_client.has_collection.return_value = True + mock_milvus_client.list_collections.return_value = ["test_collection"] await milvus_index.add_chunks(sample_chunks, sample_embeddings) # Test keyword search @@ -146,7 +146,7 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): """Test that when BM25 search fails, the system falls back to simple text search.""" - mock_milvus_client.has_collection.return_value = True + mock_milvus_client.list_collections.return_value = ["test_collection"] await milvus_index.add_chunks(sample_chunks, sample_embeddings) # Force BM25 search to fail @@ -188,7 +188,7 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl async def test_delete_collection(milvus_index, mock_milvus_client): # Test collection deletion - mock_milvus_client.has_collection.return_value = True + mock_milvus_client.list_collections.return_value = ["test_collection"] await milvus_index.delete() @@ -199,7 +199,7 @@ async def test_query_hybrid_search_rrf( milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client ): """Test hybrid search with RRF reranker.""" - mock_milvus_client.has_collection.return_value = True + mock_milvus_client.list_collections.return_value = ["test_collection"] await milvus_index.add_chunks(sample_chunks, sample_embeddings) # Mock hybrid search results @@ -251,7 +251,7 @@ async def test_query_hybrid_search_weighted( milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client ): """Test hybrid search with weighted reranker.""" - mock_milvus_client.has_collection.return_value = True + mock_milvus_client.list_collections.return_value = ["test_collection"] await milvus_index.add_chunks(sample_chunks, sample_embeddings) # Mock hybrid search results @@ -297,7 +297,7 @@ async def test_query_hybrid_search_default_rrf( milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client ): """Test hybrid search with default RRF reranker (no reranker_type specified).""" - mock_milvus_client.has_collection.return_value = True + mock_milvus_client.list_collections.return_value = ["test_collection"] await milvus_index.add_chunks(sample_chunks, sample_embeddings) # Mock hybrid search results From 3e8291c1b94be537314d1d30371b7a98bf5b2a5c Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 3 Nov 2025 23:20:39 +0100 Subject: [PATCH 7/9] fix: resolve rebase conflicts Signed-off-by: Mustafa Elbehery --- tests/unit/providers/vector_io/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 7426219c1..eef99e718 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest -from chromadb import PersistentClient from pymilvus import AsyncMilvusClient, connections from llama_stack.apis.vector_dbs import VectorDB From c6bf292f07975154eb7f26ef6d9bf160f3bb6b41 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 3 Nov 2025 23:38:44 +0100 Subject: [PATCH 8/9] fix(unittest): add pymilvus and milvus-lite to unit dep group This commit resolves unittest-3.12 issues. - It adds pymilvus and milvus-lite to unit dep group - It rename VectorDB to VectorStore Signed-off-by: Mustafa Elbehery --- pyproject.toml | 2 ++ tests/conftest.py | 8 ++++++++ tests/unit/providers/vector_io/conftest.py | 5 ++--- uv.lock | 6 +++++- 4 files changed, 17 insertions(+), 4 deletions(-) create mode 100644 tests/conftest.py diff --git a/pyproject.toml b/pyproject.toml index e6808af8a..6b2f9a585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,8 @@ unit = [ "together", "coverage", "moto[s3]>=5.1.10", + "pymilvus>=2.6.1", + "milvus-lite>=2.5.0", ] # These are the core dependencies required for running integration tests. They are shared across all # providers. If a provider requires additional dependencies, please add them to your environment diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..fa9bd9912 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# This file intentionally left empty - pytest will auto-discover conftest.py files +# in subdirectories and load their fixtures automatically. diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index eef99e718..9d3ddb5cb 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -11,7 +11,6 @@ import numpy as np import pytest from pymilvus import AsyncMilvusClient, connections -from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_stores import VectorStore from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig @@ -205,8 +204,8 @@ async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): files_api=None, ) await adapter.initialize() - await adapter.register_vector_db( - VectorDB( + await adapter.register_vector_store( + VectorStore( identifier=adapter.metadata_collection_name, provider_id="test_provider", embedding_model="test_model", diff --git a/uv.lock b/uv.lock index f1808f005..063ca30a8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -2074,9 +2074,11 @@ unit = [ { name = "faiss-cpu" }, { name = "litellm" }, { name = "mcp" }, + { name = "milvus-lite" }, { name = "moto", extra = ["s3"] }, { name = "ollama" }, { name = "psycopg2-binary" }, + { name = "pymilvus" }, { name = "pypdf" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlite-vec" }, @@ -2216,9 +2218,11 @@ unit = [ { name = "faiss-cpu" }, { name = "litellm" }, { name = "mcp" }, + { name = "milvus-lite", specifier = ">=2.5.0" }, { name = "moto", extras = ["s3"], specifier = ">=5.1.10" }, { name = "ollama" }, { name = "psycopg2-binary", specifier = ">=2.9.0" }, + { name = "pymilvus", specifier = ">=2.6.1" }, { name = "pypdf", specifier = ">=6.1.3" }, { name = "sqlalchemy" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, From a6eed997906cf9298b685ec121f9b31cfd19e20a Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 3 Nov 2025 23:58:15 +0100 Subject: [PATCH 9/9] fix(unittest): add required chunk_id field to Milvus test mock data The Chunk Pydantic model requires chunk_id as a mandatory field, but the mock data in test_milvus.py was missing this field in chunk_content objects, causing 6 test failures with ValidationError. Signed-off-by: Mustafa Elbehery --- .../providers/vector_io/remote/test_milvus.py | 74 ++++++++++++++++--- 1 file changed, 62 insertions(+), 12 deletions(-) diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 7020900d7..24e771a96 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -52,12 +52,24 @@ async def mock_milvus_client() -> MagicMock: { "id": 0, "distance": 0.1, - "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + "entity": { + "chunk_content": { + "chunk_id": "chunk1", + "content": "mock chunk 1", + "metadata": {"document_id": "doc1"}, + } + }, }, { "id": 1, "distance": 0.2, - "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + "entity": { + "chunk_content": { + "chunk_id": "chunk2", + "content": "mock chunk 2", + "metadata": {"document_id": "doc2"}, + } + }, }, ] ] @@ -68,17 +80,17 @@ async def mock_milvus_client() -> MagicMock: return_value=[ { "chunk_id": "chunk1", - "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, + "chunk_content": {"chunk_id": "chunk1", "content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, "score": 0.9, }, { "chunk_id": "chunk2", - "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, + "chunk_content": {"chunk_id": "chunk2", "content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, "score": 0.8, }, { "chunk_id": "chunk3", - "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, + "chunk_content": {"chunk_id": "chunk3", "content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, "score": 0.7, }, ] @@ -156,11 +168,19 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl mock_milvus_client.query.return_value = [ { "chunk_id": "chunk1", - "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}}, + "chunk_content": { + "chunk_id": "chunk1", + "content": "Python programming language", + "metadata": {"document_id": "doc1"}, + }, }, { "chunk_id": "chunk2", - "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}}, + "chunk_content": { + "chunk_id": "chunk2", + "content": "Machine learning algorithms", + "metadata": {"document_id": "doc2"}, + }, }, ] @@ -208,12 +228,24 @@ async def test_query_hybrid_search_rrf( { "id": 0, "distance": 0.1, - "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + "entity": { + "chunk_content": { + "chunk_id": "chunk1", + "content": "mock chunk 1", + "metadata": {"document_id": "doc1"}, + } + }, }, { "id": 1, "distance": 0.2, - "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + "entity": { + "chunk_content": { + "chunk_id": "chunk2", + "content": "mock chunk 2", + "metadata": {"document_id": "doc2"}, + } + }, }, ] ] @@ -260,12 +292,24 @@ async def test_query_hybrid_search_weighted( { "id": 0, "distance": 0.1, - "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + "entity": { + "chunk_content": { + "chunk_id": "chunk1", + "content": "mock chunk 1", + "metadata": {"document_id": "doc1"}, + } + }, }, { "id": 1, "distance": 0.2, - "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + "entity": { + "chunk_content": { + "chunk_id": "chunk2", + "content": "mock chunk 2", + "metadata": {"document_id": "doc2"}, + } + }, }, ] ] @@ -306,7 +350,13 @@ async def test_query_hybrid_search_default_rrf( { "id": 0, "distance": 0.1, - "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + "entity": { + "chunk_content": { + "chunk_id": "chunk1", + "content": "mock chunk 1", + "metadata": {"document_id": "doc1"}, + } + }, }, ] ]