feat: Implement keyword search in milvus (#2231)
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Integration Tests / discover-tests (push) Successful in 8s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 10s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 8s
Test Llama Stack Build / generate-matrix (push) Successful in 8s
Python Package Build Test / build (3.13) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 13s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Test External Providers / test-external-providers (venv) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 14s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 51s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 55s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 57s
Update ReadTheDocs / update-readthedocs (push) Failing after 50s
Pre-commit / pre-commit (push) Successful in 2m9s

# What does this PR do?
This PR adds the keyword search implementation for Milvus. Along with
the implementation for remote Milvus, the tests require us to start a
Milvus containers locally.

In order to verify the implementation, run:
```
pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=short --disable-warnings --asyncio-mode=auto
```

You can also test the changes using the below script:
```
#!/usr/bin/env python3
import asyncio
import os
import uuid
from typing import List

from llama_stack_client import (
    Agent, 
    AgentEventLogger, 
    LlamaStackClient, 
    RAGDocument
)


class MilvusRAGDemo:
    def __init__(self, base_url: str = "http://localhost:8321/"):
        self.client = LlamaStackClient(base_url=base_url)
        self.vector_db_id = f"milvus_rag_demo_{uuid.uuid4().hex[:8]}"
        self.model_id = None
        self.embedding_model_id = None
        self.embedding_dimension = None
        
    def setup_models(self):
        """Get available models and select appropriate ones for LLM and embeddings."""
        models = self.client.models.list()
    
        # Select embedding model
        embedding_models = [m for m in models if m.model_type == "embedding"]
        if not embedding_models:
            raise ValueError("No embedding models found")
        self.embedding_model_id = embedding_models[0].identifier
        self.embedding_dimension = embedding_models[0].metadata["embedding_dimension"]
        
    def register_vector_db(self):
        print(f"Registering Milvus vector database: {self.vector_db_id}")
        
        response = self.client.vector_dbs.register(
            vector_db_id=self.vector_db_id,
            embedding_model=self.embedding_model_id,
            embedding_dimension=self.embedding_dimension,
            provider_id="milvus-remote",  # Use remote Milvus
        )
        print(f"Vector database registered successfully")
        return response
        
    def insert_documents(self):
        """Insert sample documents into the vector database."""
        print("\nInserting sample documents...")
        
        # Sample documents about different topics
        documents = [
            RAGDocument(
                document_id="ai_ml_basics",
                content="""
                Artificial Intelligence (AI) and Machine Learning (ML) are transforming the world.
                AI refers to the simulation of human intelligence in machines, while ML is a subset
                of AI that enables computers to learn and improve from experience without being
                explicitly programmed. Deep learning, a subset of ML, uses neural networks with
                multiple layers to process complex patterns in data.
                
                Key concepts in AI/ML include:
                - Supervised Learning: Training with labeled data
                - Unsupervised Learning: Finding patterns in unlabeled data
                - Reinforcement Learning: Learning through trial and error
                - Neural Networks: Computing systems inspired by biological brains
                """,
                mime_type="text/plain",
                metadata={"topic": "technology", "category": "ai_ml"},
            ),
        ]
        
        # Insert documents with chunking
        self.client.tool_runtime.rag_tool.insert(
            documents=documents,
            vector_db_id=self.vector_db_id,
            chunk_size_in_tokens=200,  # Smaller chunks for better granularity
        )
        print(f"Inserted {len(documents)} documents with chunking")
                
    def test_keyword_search(self):
        """Test keyword-based search using BM25."""
        
        queries = [
            "neural networks",
            "Python frameworks",
            "data cleaning",
        ]
        
        for query in queries:
            response = self.client.vector_io.query(
                vector_db_id=self.vector_db_id,
                query=query,
                params={
                    "mode": "keyword",  # Keyword search
                    "max_chunks": 3,
                    "score_threshold": 0.0,
                }
            )
            
            for i, (chunk, score) in enumerate(zip(response.chunks, response.scores)):
                print(f"  {i+1}. Score: {score:.4f}")
                print(f"     Content: {chunk.content[:100]}...")
                print(f"     Metadata: {chunk.metadata}")    

                
    def run_demo(self):       
        try:
            self.setup_models()
            self.register_vector_db()
            self.insert_documents()
            self.test_keyword_search()
        except Exception as e:
            print(f"Error during demo: {e}")
            raise


def main():
    """Main function to run the demo."""
    # Check if Llama Stack server is running
    demo = MilvusRAGDemo()    
    try:
        demo.run_demo()
    except Exception as e:
        print(f"Demo failed: {e}")

if __name__ == "__main__":
    main()
```

[//]: # (## Documentation)

---------

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha 2025-07-14 16:39:55 -07:00 committed by GitHub
parent 33f0d83ad3
commit 4ae5656c2f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 331 additions and 8 deletions

View file

@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
```yaml
uri: ${env.MILVUS_ENDPOINT}
token: ${env.MILVUS_TOKEN}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
```

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
),
}

View file

@ -12,7 +12,7 @@ import re
from typing import Any
from numpy.typing import NDArray
from pymilvus import DataType, MilvusClient
from pymilvus import DataType, Function, FunctionType, MilvusClient
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
# Add sparse vector field for BM25 (required by the function)
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
# Add BM25 function for full-text search
bm25_function = Function(
name="text_bm25_emb",
input_field_names=["content"],
output_field_names=["sparse"],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
data.append(
{
"chunk_id": chunk.chunk_id,
"content": chunk.content,
"vector": embedding,
"chunk_content": chunk.model_dump(),
# sparse field will be handled by BM25 function automatically
}
)
try:
@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
self.client.search,
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
"""
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
output_fields=["chunk_content"], # Output the chunk content
limit=k,
search_params={
"params": {
"drop_ratio_search": 0.2, # Ignore low-importance terms
}
},
)
chunks = []
scores = []
for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk)
scores.append(res["distance"]) # BM25 score from Milvus
# Filter by score threshold
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
filtered_scores = [score for score in scores if score >= score_threshold]
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
except Exception as e:
logger.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
output_fields=["*"],
limit=k,
)
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
scores = [1.0] * len(chunks) # Simple binary score for text search
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
@ -247,6 +361,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
if params and params.get("mode") == "keyword":
# Check if this is inline Milvus (Milvus-Lite)
if hasattr(self.config, "db_path"):
raise NotImplementedError(
"Keyword search is not supported in Milvus-Lite. "
"Please use a remote Milvus server for keyword search functionality."
)
return await index.query_chunks(query, params)
async def _save_openai_vector_store_file(

View file

@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest_asyncio.fixture
async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
return client
@pytest_asyncio.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
yield index
# No real cleanup needed since we're using mocks
@pytest.mark.asyncio
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Verify collection was created and data was inserted
mock_milvus_client.create_collection.assert_called_once()
mock_milvus_client.insert.assert_called_once()
# Verify the insert call had the right number of chunks
insert_call = mock_milvus_client.insert.call_args
assert len(insert_call[1]["data"]) == len(sample_chunks)
@pytest.mark.asyncio
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_milvus_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.asyncio
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
# Mock simple text search results
mock_milvus_client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
},
]
# Test keyword search that should fall back to simple text search
query_string = "Python"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
# Verify response structure
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0, "Fallback search should return results"
# Verify that simple text search was used (query method called instead of search)
mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query uses parameterized filter with filter_params
query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
@pytest.mark.asyncio
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)