From 143bc7eb74853802fd4686d1a4503c5d6a48042f Mon Sep 17 00:00:00 2001 From: Varsha Prasad Narsing Date: Sat, 12 Jul 2025 17:09:50 -0700 Subject: [PATCH] Fix adding index for BM25 func Signed-off-by: Varsha Prasad Narsing --- docs/source/providers/vector_io/remote_milvus.md | 5 ++++- .../providers/remote/vector_io/milvus/config.py | 13 ++++++++++--- .../providers/remote/vector_io/milvus/milvus.py | 15 ++++++++++++--- .../providers/vector_io/remote/test_milvus.py | 7 ++++--- 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index f3089e615..6734d8315 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi | `uri` | `` | No | PydanticUndefined | The URI of the Milvus server | | `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server | | `consistency_level` | `` | No | Strong | The consistency level of the Milvus server | -| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend | | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. | > **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. @@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi ```yaml uri: ${env.MILVUS_ENDPOINT} token: ${env.MILVUS_TOKEN} +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db ``` diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py index e3f51b4f4..899d3678d 100644 --- a/llama_stack/providers/remote/vector_io/milvus/config.py +++ b/llama_stack/providers/remote/vector_io/milvus/config.py @@ -8,7 +8,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from llama_stack.providers.utils.kvstore.config import KVStoreConfig +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.schema_utils import json_schema_type @@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel): uri: str = Field(description="The URI of the Milvus server") token: str | None = Field(description="The token of the Milvus server") consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong") - kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None) + kvstore: KVStoreConfig = Field(description="Config for KV store backend") # This configuration allows additional fields to be passed through to the underlying Milvus client. # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. @@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel): @classmethod def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: - return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"} + return { + "uri": "${env.MILVUS_ENDPOINT}", + "token": "${env.MILVUS_TOKEN}", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="milvus_remote_registry.db", + ), + } diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 5fd065a9a..290564953 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -74,7 +74,9 @@ class MilvusIndex(EmbeddingIndex): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) + if not await asyncio.to_thread(self.client.has_collection, self.collection_name): + logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") # Create schema for vector search schema = self.client.create_schema() schema.add_field( @@ -98,7 +100,7 @@ class MilvusIndex(EmbeddingIndex): field_name="chunk_content", datatype=DataType.JSON, ) - # Add sparse vector field for BM25 + # Add sparse vector field for BM25 (required by the function) schema.add_field( field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR, @@ -111,6 +113,12 @@ class MilvusIndex(EmbeddingIndex): index_type="FLAT", metric_type="COSINE", ) + # Add index for sparse field (required by BM25 function) + index_params.add_index( + field_name="sparse", + index_type="SPARSE_INVERTED_INDEX", + metric_type="BM25", + ) # Add BM25 function for full-text search bm25_function = Function( @@ -137,7 +145,7 @@ class MilvusIndex(EmbeddingIndex): "content": chunk.content, "vector": embedding, "chunk_content": chunk.model_dump(), - # sparse field will be automatically populated by BM25 function + # sparse field will be handled by BM25 function automatically } ) try: @@ -220,7 +228,8 @@ class MilvusIndex(EmbeddingIndex): search_res = await asyncio.to_thread( self.client.query, collection_name=self.collection_name, - filter=f'content like "%{query_string}%"', + filter='content like "%{content}%"', + filter_params={"content": query_string}, output_fields=["*"], limit=k, ) diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 8e4366b99..2f212e374 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -34,7 +34,7 @@ MILVUS_PROVIDER = "milvus" @pytest_asyncio.fixture -async def mock_milvus_client(): +async def mock_milvus_client() -> MagicMock: """Create a mock Milvus client with common method behaviors.""" client = MagicMock() @@ -171,10 +171,11 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl mock_milvus_client.query.assert_called_once() mock_milvus_client.search.assert_called_once() # Called once but failed - # Verify the query filter contains the search term + # Verify the query uses parameterized filter with filter_params query_call_args = mock_milvus_client.query.call_args assert "filter" in query_call_args[1], "Query should include filter for text search" - assert "Python" in query_call_args[1]["filter"], "Filter should contain the search term" + assert "filter_params" in query_call_args[1], "Query should use parameterized filter" + assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term" # Verify all returned chunks have score 1.0 (simple binary scoring) assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"