Fix adding index for BM25 func

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-07-12 17:09:50 -07:00 committed by Varsha
parent ac039e6bac
commit 143bc7eb74
4 changed files with 30 additions and 10 deletions

View file

@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server | | `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server | | `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server | | `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) | | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. | | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. > **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
```yaml ```yaml
uri: ${env.MILVUS_ENDPOINT} uri: ${env.MILVUS_ENDPOINT}
token: ${env.MILVUS_TOKEN} token: ${env.MILVUS_TOKEN}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
``` ```

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server") uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server") token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong") consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None) kvstore: KVStoreConfig = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client. # This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
@classmethod @classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"} return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
),
}

View file

@ -74,7 +74,9 @@ class MilvusIndex(EmbeddingIndex):
assert len(chunks) == len(embeddings), ( assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
) )
if not await asyncio.to_thread(self.client.has_collection, self.collection_name): if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search # Create schema for vector search
schema = self.client.create_schema() schema = self.client.create_schema()
schema.add_field( schema.add_field(
@ -98,7 +100,7 @@ class MilvusIndex(EmbeddingIndex):
field_name="chunk_content", field_name="chunk_content",
datatype=DataType.JSON, datatype=DataType.JSON,
) )
# Add sparse vector field for BM25 # Add sparse vector field for BM25 (required by the function)
schema.add_field( schema.add_field(
field_name="sparse", field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR, datatype=DataType.SPARSE_FLOAT_VECTOR,
@ -111,6 +113,12 @@ class MilvusIndex(EmbeddingIndex):
index_type="FLAT", index_type="FLAT",
metric_type="COSINE", metric_type="COSINE",
) )
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
# Add BM25 function for full-text search # Add BM25 function for full-text search
bm25_function = Function( bm25_function = Function(
@ -137,7 +145,7 @@ class MilvusIndex(EmbeddingIndex):
"content": chunk.content, "content": chunk.content,
"vector": embedding, "vector": embedding,
"chunk_content": chunk.model_dump(), "chunk_content": chunk.model_dump(),
# sparse field will be automatically populated by BM25 function # sparse field will be handled by BM25 function automatically
} }
) )
try: try:
@ -220,7 +228,8 @@ class MilvusIndex(EmbeddingIndex):
search_res = await asyncio.to_thread( search_res = await asyncio.to_thread(
self.client.query, self.client.query,
collection_name=self.collection_name, collection_name=self.collection_name,
filter=f'content like "%{query_string}%"', filter='content like "%{content}%"',
filter_params={"content": query_string},
output_fields=["*"], output_fields=["*"],
limit=k, limit=k,
) )

View file

@ -34,7 +34,7 @@ MILVUS_PROVIDER = "milvus"
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def mock_milvus_client(): async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors.""" """Create a mock Milvus client with common method behaviors."""
client = MagicMock() client = MagicMock()
@ -171,10 +171,11 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
mock_milvus_client.query.assert_called_once() mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query filter contains the search term # Verify the query uses parameterized filter with filter_params
query_call_args = mock_milvus_client.query.call_args query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search" assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "Python" in query_call_args[1]["filter"], "Filter should contain the search term" assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring) # Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring" assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"