Fix formatting and finalize Milvus BM25 integration

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-07-07 12:48:40 -07:00
parent 29801783b8
commit 0108bb1aa5
3 changed files with 62 additions and 39 deletions

View file

@ -101,15 +101,6 @@ vector_io:
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS). - **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
- **`client_key_path`**: Path to the **client private key** file (required for mTLS). - **`client_key_path`**: Path to the **client private key** file (required for mTLS).
## Supported Search Modes
The Milvus provider supports both vector-based and keyword-based (full-text) search modes, but with some limitations:
- Remote Milvus supports both vector-based and keyword-based search modes.
- Inline Milvus (Milvus-Lite) only supports vector-based search. Keyword search is not supported as Milvus-Lite has not implemented this functionality yet. For updates on this feature, see [Milvus GitHub Issue #40848](https://github.com/milvus-io/milvus/issues/40848).
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in `RAGQueryConfig`. For more details on Milvus's implementation of keyword search modes, refer to the [Milvus documentation](https://milvus.io/docs/full_text_search_with_milvus.md).
## Documentation ## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.

View file

@ -12,7 +12,7 @@ import re
from typing import Any from typing import Any
from numpy.typing import NDArray from numpy.typing import NDArray
from pymilvus import DataType, MilvusClient from pymilvus import DataType, Function, FunctionType, MilvusClient
from llama_stack.apis.files.files import Files from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.inference import Inference, InterleavedContent
@ -108,7 +108,6 @@ class MilvusIndex(EmbeddingIndex):
) )
# Add BM25 function for full-text search # Add BM25 function for full-text search
from pymilvus import Function, FunctionType
bm25_function = Function( bm25_function = Function(
name="text_bm25_emb", name="text_bm25_emb",
input_field_names=["content"], input_field_names=["content"],
@ -170,7 +169,7 @@ class MilvusIndex(EmbeddingIndex):
Perform BM25-based keyword search using Milvus's built-in full-text search. Perform BM25-based keyword search using Milvus's built-in full-text search.
""" """
try: try:
from pymilvus import Function, FunctionType # Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread( search_res = await asyncio.to_thread(
self.client.search, self.client.search,
collection_name=self.collection_name, collection_name=self.collection_name,
@ -184,24 +183,46 @@ class MilvusIndex(EmbeddingIndex):
} }
}, },
) )
chunks = [] chunks = []
scores = [] scores = []
for res in search_res[0]: for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"]) chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk) chunks.append(chunk)
scores.append(res["distance"]) # BM25 score from Milvus scores.append(res["distance"]) # BM25 score from Milvus
# Filter by score threshold # Filter by score threshold
filtered_results = [(chunk, score) for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold] filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
if filtered_results: filtered_scores = [score for score in scores if score >= score_threshold]
chunks, scores = zip(*filtered_results, strict=False)
return QueryChunksResponse(chunks=list(chunks), scores=list(scores)) return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
else:
return QueryChunksResponse(chunks=[], scores=[])
except Exception as e: except Exception as e:
logger.error(f"Error performing BM25 search: {e}") logger.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search # Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold) return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
collection_name=self.collection_name,
filter=f'content like "%{query_string}%"',
output_fields=["*"],
limit=k,
)
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
scores = [1.0] * len(chunks) # Simple binary score for text search
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid( async def query_hybrid(
self, self,
embedding: NDArray, embedding: NDArray,

View file

@ -132,41 +132,52 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
# Test keyword search # Test keyword search
query_string = "Sentence 5" query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0) response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse) assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 3 assert len(response.chunks) == 2
mock_milvus_client.query.assert_called_once()
# Test no results case
mock_milvus_client.query.return_value = []
response_no_results = await milvus_index.query_keyword(query_string="nonexistent", k=1, score_threshold=0.0)
assert isinstance(response_no_results, QueryChunksResponse)
assert len(response_no_results.chunks) == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_chunks_keyword_search_k_greater_than_results( async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
milvus_index, sample_chunks, sample_embeddings, mock_milvus_client """Test that when BM25 search fails, the system falls back to simple text search."""
):
mock_milvus_client.has_collection.return_value = True mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings) await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock returning only 1 result even though k=5 # Force BM25 search to fail
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
# Mock simple text search results
mock_milvus_client.query.return_value = [ mock_milvus_client.query.return_value = [
{ {
"chunk_id": "chunk1", "chunk_id": "chunk1",
"chunk_content": {"content": "Sentence 1 from document 0", "metadata": {"document_id": "doc1"}}, "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
"score": 0.9, },
} {
"chunk_id": "chunk2",
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
},
] ]
query_str = "Sentence 1 from document 0" # Test keyword search that should fall back to simple text search
response = await milvus_index.query_keyword(query_string=query_str, k=5, score_threshold=0.0) query_string = "Python"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
assert 0 < len(response.chunks) <= 4 # Verify response structure
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks) assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0, "Fallback search should return results"
# Verify that simple text search was used (query method called instead of search)
mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query filter contains the search term
query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "Python" in query_call_args[1]["filter"], "Filter should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
@pytest.mark.asyncio @pytest.mark.asyncio