mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Fix formatting and finalize Milvus BM25 integration
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
29801783b8
commit
0108bb1aa5
3 changed files with 62 additions and 39 deletions
|
@ -101,15 +101,6 @@ vector_io:
|
||||||
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
|
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
|
||||||
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
|
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
|
||||||
|
|
||||||
## Supported Search Modes
|
|
||||||
|
|
||||||
The Milvus provider supports both vector-based and keyword-based (full-text) search modes, but with some limitations:
|
|
||||||
|
|
||||||
- Remote Milvus supports both vector-based and keyword-based search modes.
|
|
||||||
- Inline Milvus (Milvus-Lite) only supports vector-based search. Keyword search is not supported as Milvus-Lite has not implemented this functionality yet. For updates on this feature, see [Milvus GitHub Issue #40848](https://github.com/milvus-io/milvus/issues/40848).
|
|
||||||
|
|
||||||
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in `RAGQueryConfig`. For more details on Milvus's implementation of keyword search modes, refer to the [Milvus documentation](https://milvus.io/docs/full_text_search_with_milvus.md).
|
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
|
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pymilvus import DataType, MilvusClient
|
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
||||||
|
|
||||||
from llama_stack.apis.files.files import Files
|
from llama_stack.apis.files.files import Files
|
||||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||||
|
@ -108,7 +108,6 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add BM25 function for full-text search
|
# Add BM25 function for full-text search
|
||||||
from pymilvus import Function, FunctionType
|
|
||||||
bm25_function = Function(
|
bm25_function = Function(
|
||||||
name="text_bm25_emb",
|
name="text_bm25_emb",
|
||||||
input_field_names=["content"],
|
input_field_names=["content"],
|
||||||
|
@ -170,7 +169,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from pymilvus import Function, FunctionType
|
# Use Milvus's built-in BM25 search
|
||||||
search_res = await asyncio.to_thread(
|
search_res = await asyncio.to_thread(
|
||||||
self.client.search,
|
self.client.search,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
|
@ -184,24 +183,46 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
scores = []
|
scores = []
|
||||||
for res in search_res[0]:
|
for res in search_res[0]:
|
||||||
chunk = Chunk(**res["entity"]["chunk_content"])
|
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
scores.append(res["distance"]) # BM25 score from Milvus
|
scores.append(res["distance"]) # BM25 score from Milvus
|
||||||
|
|
||||||
# Filter by score threshold
|
# Filter by score threshold
|
||||||
filtered_results = [(chunk, score) for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||||
if filtered_results:
|
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||||
chunks, scores = zip(*filtered_results, strict=False)
|
|
||||||
return QueryChunksResponse(chunks=list(chunks), scores=list(scores))
|
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||||
else:
|
|
||||||
return QueryChunksResponse(chunks=[], scores=[])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error performing BM25 search: {e}")
|
logger.error(f"Error performing BM25 search: {e}")
|
||||||
# Fallback to simple text search
|
# Fallback to simple text search
|
||||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||||
|
|
||||||
|
async def _fallback_keyword_search(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
"""
|
||||||
|
Fallback to simple text search when BM25 search is not available.
|
||||||
|
"""
|
||||||
|
# Simple text search using content field
|
||||||
|
search_res = await asyncio.to_thread(
|
||||||
|
self.client.query,
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
filter=f'content like "%{query_string}%"',
|
||||||
|
output_fields=["*"],
|
||||||
|
limit=k,
|
||||||
|
)
|
||||||
|
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
|
||||||
|
scores = [1.0] * len(chunks) # Simple binary score for text search
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def query_hybrid(
|
async def query_hybrid(
|
||||||
self,
|
self,
|
||||||
embedding: NDArray,
|
embedding: NDArray,
|
||||||
|
|
|
@ -132,41 +132,52 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
|
||||||
|
|
||||||
# Test keyword search
|
# Test keyword search
|
||||||
query_string = "Sentence 5"
|
query_string = "Sentence 5"
|
||||||
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
assert isinstance(response, QueryChunksResponse)
|
assert isinstance(response, QueryChunksResponse)
|
||||||
assert len(response.chunks) == 3
|
assert len(response.chunks) == 2
|
||||||
mock_milvus_client.query.assert_called_once()
|
|
||||||
|
|
||||||
# Test no results case
|
|
||||||
mock_milvus_client.query.return_value = []
|
|
||||||
response_no_results = await milvus_index.query_keyword(query_string="nonexistent", k=1, score_threshold=0.0)
|
|
||||||
|
|
||||||
assert isinstance(response_no_results, QueryChunksResponse)
|
|
||||||
assert len(response_no_results.chunks) == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_chunks_keyword_search_k_greater_than_results(
|
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
milvus_index, sample_chunks, sample_embeddings, mock_milvus_client
|
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||||
):
|
|
||||||
mock_milvus_client.has_collection.return_value = True
|
mock_milvus_client.has_collection.return_value = True
|
||||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
# Mock returning only 1 result even though k=5
|
# Force BM25 search to fail
|
||||||
|
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||||
|
|
||||||
|
# Mock simple text search results
|
||||||
mock_milvus_client.query.return_value = [
|
mock_milvus_client.query.return_value = [
|
||||||
{
|
{
|
||||||
"chunk_id": "chunk1",
|
"chunk_id": "chunk1",
|
||||||
"chunk_content": {"content": "Sentence 1 from document 0", "metadata": {"document_id": "doc1"}},
|
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
||||||
"score": 0.9,
|
},
|
||||||
}
|
{
|
||||||
|
"chunk_id": "chunk2",
|
||||||
|
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
query_str = "Sentence 1 from document 0"
|
# Test keyword search that should fall back to simple text search
|
||||||
response = await milvus_index.query_keyword(query_string=query_str, k=5, score_threshold=0.0)
|
query_string = "Python"
|
||||||
|
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||||
|
|
||||||
assert 0 < len(response.chunks) <= 4
|
# Verify response structure
|
||||||
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks)
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) > 0, "Fallback search should return results"
|
||||||
|
|
||||||
|
# Verify that simple text search was used (query method called instead of search)
|
||||||
|
mock_milvus_client.query.assert_called_once()
|
||||||
|
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||||
|
|
||||||
|
# Verify the query filter contains the search term
|
||||||
|
query_call_args = mock_milvus_client.query.call_args
|
||||||
|
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
||||||
|
assert "Python" in query_call_args[1]["filter"], "Filter should contain the search term"
|
||||||
|
|
||||||
|
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
||||||
|
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue