Feat: Implement keyword search in milvus

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-05-22 15:21:40 -07:00
parent 5561f1c36d
commit 29801783b8
3 changed files with 284 additions and 4 deletions

View file

@ -101,6 +101,15 @@ vector_io:
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS). - **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
- **`client_key_path`**: Path to the **client private key** file (required for mTLS). - **`client_key_path`**: Path to the **client private key** file (required for mTLS).
## Supported Search Modes
The Milvus provider supports both vector-based and keyword-based (full-text) search modes, but with some limitations:
- Remote Milvus supports both vector-based and keyword-based search modes.
- Inline Milvus (Milvus-Lite) only supports vector-based search. Keyword search is not supported as Milvus-Lite has not implemented this functionality yet. For updates on this feature, see [Milvus GitHub Issue #40848](https://github.com/milvus-io/milvus/issues/40848).
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in `RAGQueryConfig`. For more details on Milvus's implementation of keyword search modes, refer to the [Milvus documentation](https://milvus.io/docs/full_text_search_with_milvus.md).
## Documentation ## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.

View file

@ -70,11 +70,58 @@ class MilvusIndex(EmbeddingIndex):
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
) )
if not await asyncio.to_thread(self.client.has_collection, self.collection_name): if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
# Add sparse vector field for BM25
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
# Add BM25 function for full-text search
from pymilvus import Function, FunctionType
bm25_function = Function(
name="text_bm25_emb",
input_field_names=["content"],
output_field_names=["sparse"],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
await asyncio.to_thread( await asyncio.to_thread(
self.client.create_collection, self.client.create_collection,
self.collection_name, self.collection_name,
dimension=len(embeddings[0]), schema=schema,
auto_id=True, index_params=index_params,
consistency_level=self.consistency_level, consistency_level=self.consistency_level,
) )
@ -83,8 +130,10 @@ class MilvusIndex(EmbeddingIndex):
data.append( data.append(
{ {
"chunk_id": chunk.chunk_id, "chunk_id": chunk.chunk_id,
"content": chunk.content,
"vector": embedding, "vector": embedding,
"chunk_content": chunk.model_dump(), "chunk_content": chunk.model_dump(),
# sparse field will be automatically populated by BM25 function
} }
) )
try: try:
@ -102,9 +151,10 @@ class MilvusIndex(EmbeddingIndex):
self.client.search, self.client.search,
collection_name=self.collection_name, collection_name=self.collection_name,
data=[embedding], data=[embedding],
anns_field="vector",
limit=k, limit=k,
output_fields=["*"], output_fields=["*"],
search_params={"params": {"radius": score_threshold}}, search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
) )
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]] chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
scores = [res["distance"] for res in search_res[0]] scores = [res["distance"] for res in search_res[0]]
@ -116,7 +166,41 @@ class MilvusIndex(EmbeddingIndex):
k: int, k: int,
score_threshold: float, score_threshold: float,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus") """
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
try:
from pymilvus import Function, FunctionType
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
output_fields=["chunk_content"], # Output the chunk content
limit=k,
search_params={
"params": {
"drop_ratio_search": 0.2, # Ignore low-importance terms
}
},
)
chunks = []
scores = []
for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk)
scores.append(res["distance"]) # BM25 score from Milvus
# Filter by score threshold
filtered_results = [(chunk, score) for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
if filtered_results:
chunks, scores = zip(*filtered_results, strict=False)
return QueryChunksResponse(chunks=list(chunks), scores=list(scores))
else:
return QueryChunksResponse(chunks=[], scores=[])
except Exception as e:
logger.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def query_hybrid( async def query_hybrid(
self, self,
@ -238,6 +322,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if not index: if not index:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
if params and params.get("mode") == "keyword":
# Check if this is inline Milvus (Milvus-Lite)
if hasattr(self.config, "db_path"):
raise NotImplementedError(
"Keyword search is not supported in Milvus-Lite. "
"Please use a remote Milvus server for keyword search functionality."
)
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:

View file

@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest_asyncio.fixture
async def mock_milvus_client():
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
return client
@pytest_asyncio.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
yield index
# No real cleanup needed since we're using mocks
@pytest.mark.asyncio
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Verify collection was created and data was inserted
mock_milvus_client.create_collection.assert_called_once()
mock_milvus_client.insert.assert_called_once()
# Verify the insert call had the right number of chunks
insert_call = mock_milvus_client.insert.call_args
assert len(insert_call[1]["data"]) == len(sample_chunks)
@pytest.mark.asyncio
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_milvus_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 3
mock_milvus_client.query.assert_called_once()
# Test no results case
mock_milvus_client.query.return_value = []
response_no_results = await milvus_index.query_keyword(query_string="nonexistent", k=1, score_threshold=0.0)
assert isinstance(response_no_results, QueryChunksResponse)
assert len(response_no_results.chunks) == 0
@pytest.mark.asyncio
async def test_query_chunks_keyword_search_k_greater_than_results(
milvus_index, sample_chunks, sample_embeddings, mock_milvus_client
):
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock returning only 1 result even though k=5
mock_milvus_client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "Sentence 1 from document 0", "metadata": {"document_id": "doc1"}},
"score": 0.9,
}
]
query_str = "Sentence 1 from document 0"
response = await milvus_index.query_keyword(query_string=query_str, k=5, score_threshold=0.0)
assert 0 < len(response.chunks) <= 4
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks)
@pytest.mark.asyncio
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)