feat: Implement keyword search in milvus (#2231)
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Integration Tests / discover-tests (push) Successful in 8s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 10s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 8s
Test Llama Stack Build / generate-matrix (push) Successful in 8s
Python Package Build Test / build (3.13) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 13s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Test External Providers / test-external-providers (venv) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 14s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 51s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 55s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 57s
Update ReadTheDocs / update-readthedocs (push) Failing after 50s
Pre-commit / pre-commit (push) Successful in 2m9s

# What does this PR do?
This PR adds the keyword search implementation for Milvus. Along with
the implementation for remote Milvus, the tests require us to start a
Milvus containers locally.

In order to verify the implementation, run:
```
pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=short --disable-warnings --asyncio-mode=auto
```

You can also test the changes using the below script:
```
#!/usr/bin/env python3
import asyncio
import os
import uuid
from typing import List

from llama_stack_client import (
    Agent, 
    AgentEventLogger, 
    LlamaStackClient, 
    RAGDocument
)


class MilvusRAGDemo:
    def __init__(self, base_url: str = "http://localhost:8321/"):
        self.client = LlamaStackClient(base_url=base_url)
        self.vector_db_id = f"milvus_rag_demo_{uuid.uuid4().hex[:8]}"
        self.model_id = None
        self.embedding_model_id = None
        self.embedding_dimension = None
        
    def setup_models(self):
        """Get available models and select appropriate ones for LLM and embeddings."""
        models = self.client.models.list()
    
        # Select embedding model
        embedding_models = [m for m in models if m.model_type == "embedding"]
        if not embedding_models:
            raise ValueError("No embedding models found")
        self.embedding_model_id = embedding_models[0].identifier
        self.embedding_dimension = embedding_models[0].metadata["embedding_dimension"]
        
    def register_vector_db(self):
        print(f"Registering Milvus vector database: {self.vector_db_id}")
        
        response = self.client.vector_dbs.register(
            vector_db_id=self.vector_db_id,
            embedding_model=self.embedding_model_id,
            embedding_dimension=self.embedding_dimension,
            provider_id="milvus-remote",  # Use remote Milvus
        )
        print(f"Vector database registered successfully")
        return response
        
    def insert_documents(self):
        """Insert sample documents into the vector database."""
        print("\nInserting sample documents...")
        
        # Sample documents about different topics
        documents = [
            RAGDocument(
                document_id="ai_ml_basics",
                content="""
                Artificial Intelligence (AI) and Machine Learning (ML) are transforming the world.
                AI refers to the simulation of human intelligence in machines, while ML is a subset
                of AI that enables computers to learn and improve from experience without being
                explicitly programmed. Deep learning, a subset of ML, uses neural networks with
                multiple layers to process complex patterns in data.
                
                Key concepts in AI/ML include:
                - Supervised Learning: Training with labeled data
                - Unsupervised Learning: Finding patterns in unlabeled data
                - Reinforcement Learning: Learning through trial and error
                - Neural Networks: Computing systems inspired by biological brains
                """,
                mime_type="text/plain",
                metadata={"topic": "technology", "category": "ai_ml"},
            ),
        ]
        
        # Insert documents with chunking
        self.client.tool_runtime.rag_tool.insert(
            documents=documents,
            vector_db_id=self.vector_db_id,
            chunk_size_in_tokens=200,  # Smaller chunks for better granularity
        )
        print(f"Inserted {len(documents)} documents with chunking")
                
    def test_keyword_search(self):
        """Test keyword-based search using BM25."""
        
        queries = [
            "neural networks",
            "Python frameworks",
            "data cleaning",
        ]
        
        for query in queries:
            response = self.client.vector_io.query(
                vector_db_id=self.vector_db_id,
                query=query,
                params={
                    "mode": "keyword",  # Keyword search
                    "max_chunks": 3,
                    "score_threshold": 0.0,
                }
            )
            
            for i, (chunk, score) in enumerate(zip(response.chunks, response.scores)):
                print(f"  {i+1}. Score: {score:.4f}")
                print(f"     Content: {chunk.content[:100]}...")
                print(f"     Metadata: {chunk.metadata}")    

                
    def run_demo(self):       
        try:
            self.setup_models()
            self.register_vector_db()
            self.insert_documents()
            self.test_keyword_search()
        except Exception as e:
            print(f"Error during demo: {e}")
            raise


def main():
    """Main function to run the demo."""
    # Check if Llama Stack server is running
    demo = MilvusRAGDemo()    
    try:
        demo.run_demo()
    except Exception as e:
        print(f"Demo failed: {e}")

if __name__ == "__main__":
    main()
```

[//]: # (## Documentation)

---------

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha 2025-07-14 16:39:55 -07:00 committed by GitHub
parent 33f0d83ad3
commit 4ae5656c2f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 331 additions and 8 deletions

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
),
}

View file

@ -12,7 +12,7 @@ import re
from typing import Any
from numpy.typing import NDArray
from pymilvus import DataType, MilvusClient
from pymilvus import DataType, Function, FunctionType, MilvusClient
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
# Add sparse vector field for BM25 (required by the function)
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
# Add BM25 function for full-text search
bm25_function = Function(
name="text_bm25_emb",
input_field_names=["content"],
output_field_names=["sparse"],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
data.append(
{
"chunk_id": chunk.chunk_id,
"content": chunk.content,
"vector": embedding,
"chunk_content": chunk.model_dump(),
# sparse field will be handled by BM25 function automatically
}
)
try:
@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
self.client.search,
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
"""
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
output_fields=["chunk_content"], # Output the chunk content
limit=k,
search_params={
"params": {
"drop_ratio_search": 0.2, # Ignore low-importance terms
}
},
)
chunks = []
scores = []
for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk)
scores.append(res["distance"]) # BM25 score from Milvus
# Filter by score threshold
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
filtered_scores = [score for score in scores if score >= score_threshold]
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
except Exception as e:
logger.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
output_fields=["*"],
limit=k,
)
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
scores = [1.0] * len(chunks) # Simple binary score for text search
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
@ -247,6 +361,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
if params and params.get("mode") == "keyword":
# Check if this is inline Milvus (Milvus-Lite)
if hasattr(self.config, "db_path"):
raise NotImplementedError(
"Keyword search is not supported in Milvus-Lite. "
"Please use a remote Milvus server for keyword search functionality."
)
return await index.query_chunks(query, params)
async def _save_openai_vector_store_file(