mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-18 02:42:31 +00:00
feat: Implement keyword search in milvus (#2231)
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Integration Tests / discover-tests (push) Successful in 8s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 10s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 8s
Test Llama Stack Build / generate-matrix (push) Successful in 8s
Python Package Build Test / build (3.13) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 13s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Test External Providers / test-external-providers (venv) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 14s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 51s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 55s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 57s
Update ReadTheDocs / update-readthedocs (push) Failing after 50s
Pre-commit / pre-commit (push) Successful in 2m9s
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Integration Tests / discover-tests (push) Successful in 8s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 10s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 8s
Test Llama Stack Build / generate-matrix (push) Successful in 8s
Python Package Build Test / build (3.13) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 13s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Test External Providers / test-external-providers (venv) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 14s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 51s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 55s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 57s
Update ReadTheDocs / update-readthedocs (push) Failing after 50s
Pre-commit / pre-commit (push) Successful in 2m9s
# What does this PR do? This PR adds the keyword search implementation for Milvus. Along with the implementation for remote Milvus, the tests require us to start a Milvus container locally. In order to verify the implementation, run: ``` pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=short --disable-warnings --asyncio-mode=auto ``` You can also test the changes using the script below: ``` #!/usr/bin/env python3 import asyncio import os import uuid from typing import List from llama_stack_client import ( Agent, AgentEventLogger, LlamaStackClient, RAGDocument ) class MilvusRAGDemo: def __init__(self, base_url: str = "http://localhost:8321/"): self.client = LlamaStackClient(base_url=base_url) self.vector_db_id = f"milvus_rag_demo_{uuid.uuid4().hex[:8]}" self.model_id = None self.embedding_model_id = None self.embedding_dimension = None def setup_models(self): """Get available models and select appropriate ones for LLM and embeddings.""" models = self.client.models.list() # Select embedding model embedding_models = [m for m in models if m.model_type == "embedding"] if not embedding_models: raise ValueError("No embedding models found") self.embedding_model_id = embedding_models[0].identifier self.embedding_dimension = embedding_models[0].metadata["embedding_dimension"] def register_vector_db(self): print(f"Registering Milvus vector database: {self.vector_db_id}") response = self.client.vector_dbs.register( vector_db_id=self.vector_db_id, embedding_model=self.embedding_model_id, embedding_dimension=self.embedding_dimension, provider_id="milvus-remote", # Use remote Milvus ) print(f"Vector database registered successfully") return response def insert_documents(self): """Insert sample documents into the vector database.""" print("\nInserting sample documents...") # Sample documents about different topics documents = [ RAGDocument( document_id="ai_ml_basics", content=""" Artificial Intelligence (AI) and Machine Learning (ML) are transforming the 
world. AI refers to the simulation of human intelligence in machines, while ML is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed. Deep learning, a subset of ML, uses neural networks with multiple layers to process complex patterns in data. Key concepts in AI/ML include: - Supervised Learning: Training with labeled data - Unsupervised Learning: Finding patterns in unlabeled data - Reinforcement Learning: Learning through trial and error - Neural Networks: Computing systems inspired by biological brains """, mime_type="text/plain", metadata={"topic": "technology", "category": "ai_ml"}, ), ] # Insert documents with chunking self.client.tool_runtime.rag_tool.insert( documents=documents, vector_db_id=self.vector_db_id, chunk_size_in_tokens=200, # Smaller chunks for better granularity ) print(f"Inserted {len(documents)} documents with chunking") def test_keyword_search(self): """Test keyword-based search using BM25.""" queries = [ "neural networks", "Python frameworks", "data cleaning", ] for query in queries: response = self.client.vector_io.query( vector_db_id=self.vector_db_id, query=query, params={ "mode": "keyword", # Keyword search "max_chunks": 3, "score_threshold": 0.0, } ) for i, (chunk, score) in enumerate(zip(response.chunks, response.scores)): print(f" {i+1}. Score: {score:.4f}") print(f" Content: {chunk.content[:100]}...") print(f" Metadata: {chunk.metadata}") def run_demo(self): try: self.setup_models() self.register_vector_db() self.insert_documents() self.test_keyword_search() except Exception as e: print(f"Error during demo: {e}") raise def main(): """Main function to run the demo.""" # Check if Llama Stack server is running demo = MilvusRAGDemo() try: demo.run_demo() except Exception as e: print(f"Demo failed: {e}") if __name__ == "__main__": main() ``` [//]: # (## Documentation) --------- Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
33f0d83ad3
commit
4ae5656c2f
4 changed files with 331 additions and 8 deletions
|
@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
||||||
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
|
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
|
||||||
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
|
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
|
||||||
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
|
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
|
||||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
|
||||||
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
|
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
|
||||||
|
|
||||||
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
|
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
|
||||||
|
@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
||||||
```yaml
|
```yaml
|
||||||
uri: ${env.MILVUS_ENDPOINT}
|
uri: ${env.MILVUS_ENDPOINT}
|
||||||
token: ${env.MILVUS_TOKEN}
|
token: ${env.MILVUS_TOKEN}
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
|
||||||
uri: str = Field(description="The URI of the Milvus server")
|
uri: str = Field(description="The URI of the Milvus server")
|
||||||
token: str | None = Field(description="The token of the Milvus server")
|
token: str | None = Field(description="The token of the Milvus server")
|
||||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||||
|
|
||||||
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
||||||
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
||||||
|
@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||||
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
|
return {
|
||||||
|
"uri": "${env.MILVUS_ENDPOINT}",
|
||||||
|
"token": "${env.MILVUS_TOKEN}",
|
||||||
|
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||||
|
__distro_dir__=__distro_dir__,
|
||||||
|
db_name="milvus_remote_registry.db",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@ import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pymilvus import DataType, MilvusClient
|
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
||||||
|
|
||||||
from llama_stack.apis.files.files import Files
|
from llama_stack.apis.files.files import Files
|
||||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||||
|
@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
assert len(chunks) == len(embeddings), (
|
assert len(chunks) == len(embeddings), (
|
||||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||||
|
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||||
|
# Create schema for vector search
|
||||||
|
schema = self.client.create_schema()
|
||||||
|
schema.add_field(
|
||||||
|
field_name="chunk_id",
|
||||||
|
datatype=DataType.VARCHAR,
|
||||||
|
is_primary=True,
|
||||||
|
max_length=100,
|
||||||
|
)
|
||||||
|
schema.add_field(
|
||||||
|
field_name="content",
|
||||||
|
datatype=DataType.VARCHAR,
|
||||||
|
max_length=65535,
|
||||||
|
enable_analyzer=True, # Enable text analysis for BM25
|
||||||
|
)
|
||||||
|
schema.add_field(
|
||||||
|
field_name="vector",
|
||||||
|
datatype=DataType.FLOAT_VECTOR,
|
||||||
|
dim=len(embeddings[0]),
|
||||||
|
)
|
||||||
|
schema.add_field(
|
||||||
|
field_name="chunk_content",
|
||||||
|
datatype=DataType.JSON,
|
||||||
|
)
|
||||||
|
# Add sparse vector field for BM25 (required by the function)
|
||||||
|
schema.add_field(
|
||||||
|
field_name="sparse",
|
||||||
|
datatype=DataType.SPARSE_FLOAT_VECTOR,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create indexes
|
||||||
|
index_params = self.client.prepare_index_params()
|
||||||
|
index_params.add_index(
|
||||||
|
field_name="vector",
|
||||||
|
index_type="FLAT",
|
||||||
|
metric_type="COSINE",
|
||||||
|
)
|
||||||
|
# Add index for sparse field (required by BM25 function)
|
||||||
|
index_params.add_index(
|
||||||
|
field_name="sparse",
|
||||||
|
index_type="SPARSE_INVERTED_INDEX",
|
||||||
|
metric_type="BM25",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add BM25 function for full-text search
|
||||||
|
bm25_function = Function(
|
||||||
|
name="text_bm25_emb",
|
||||||
|
input_field_names=["content"],
|
||||||
|
output_field_names=["sparse"],
|
||||||
|
function_type=FunctionType.BM25,
|
||||||
|
)
|
||||||
|
schema.add_function(bm25_function)
|
||||||
|
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self.client.create_collection,
|
self.client.create_collection,
|
||||||
self.collection_name,
|
self.collection_name,
|
||||||
dimension=len(embeddings[0]),
|
schema=schema,
|
||||||
auto_id=True,
|
index_params=index_params,
|
||||||
consistency_level=self.consistency_level,
|
consistency_level=self.consistency_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
data.append(
|
data.append(
|
||||||
{
|
{
|
||||||
"chunk_id": chunk.chunk_id,
|
"chunk_id": chunk.chunk_id,
|
||||||
|
"content": chunk.content,
|
||||||
"vector": embedding,
|
"vector": embedding,
|
||||||
"chunk_content": chunk.model_dump(),
|
"chunk_content": chunk.model_dump(),
|
||||||
|
# sparse field will be handled by BM25 function automatically
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
self.client.search,
|
self.client.search,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
data=[embedding],
|
data=[embedding],
|
||||||
|
anns_field="vector",
|
||||||
limit=k,
|
limit=k,
|
||||||
output_fields=["*"],
|
output_fields=["*"],
|
||||||
search_params={"params": {"radius": score_threshold}},
|
search_params={"params": {"radius": score_threshold}},
|
||||||
|
@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
"""
|
||||||
|
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use Milvus's built-in BM25 search
|
||||||
|
search_res = await asyncio.to_thread(
|
||||||
|
self.client.search,
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
data=[query_string], # Raw text query
|
||||||
|
anns_field="sparse", # Use sparse field for BM25
|
||||||
|
output_fields=["chunk_content"], # Output the chunk content
|
||||||
|
limit=k,
|
||||||
|
search_params={
|
||||||
|
"params": {
|
||||||
|
"drop_ratio_search": 0.2, # Ignore low-importance terms
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for res in search_res[0]:
|
||||||
|
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(res["distance"]) # BM25 score from Milvus
|
||||||
|
|
||||||
|
# Filter by score threshold
|
||||||
|
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||||
|
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||||
|
|
||||||
|
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error performing BM25 search: {e}")
|
||||||
|
# Fallback to simple text search
|
||||||
|
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||||
|
|
||||||
|
async def _fallback_keyword_search(
|
||||||
|
self,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
"""
|
||||||
|
Fallback to simple text search when BM25 search is not available.
|
||||||
|
"""
|
||||||
|
# Simple text search using content field
|
||||||
|
search_res = await asyncio.to_thread(
|
||||||
|
self.client.query,
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
filter='content like "%{content}%"',
|
||||||
|
filter_params={"content": query_string},
|
||||||
|
output_fields=["*"],
|
||||||
|
limit=k,
|
||||||
|
)
|
||||||
|
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
|
||||||
|
scores = [1.0] * len(chunks) # Simple binary score for text search
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def query_hybrid(
|
async def query_hybrid(
|
||||||
self,
|
self,
|
||||||
|
@ -247,6 +361,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
|
|
||||||
|
if params and params.get("mode") == "keyword":
|
||||||
|
# Check if this is inline Milvus (Milvus-Lite)
|
||||||
|
if hasattr(self.config, "db_path"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Keyword search is not supported in Milvus-Lite. "
|
||||||
|
"Please use a remote Milvus server for keyword search functionality."
|
||||||
|
)
|
||||||
|
|
||||||
return await index.query_chunks(query, params)
|
return await index.query_chunks(query, params)
|
||||||
|
|
||||||
async def _save_openai_vector_store_file(
|
async def _save_openai_vector_store_file(
|
||||||
|
|
191
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
191
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||||
|
|
||||||
|
# Mock the entire pymilvus module
|
||||||
|
pymilvus_mock = MagicMock()
|
||||||
|
pymilvus_mock.DataType = MagicMock()
|
||||||
|
pymilvus_mock.MilvusClient = MagicMock
|
||||||
|
|
||||||
|
# Apply the mock before importing MilvusIndex
|
||||||
|
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||||
|
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||||
|
|
||||||
|
# This file contains unit tests for the MilvusIndex class. It should only contain
|
||||||
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||||
|
# tests/integration/vector_io/
|
||||||
|
#
|
||||||
|
# How to run this test:
|
||||||
|
#
|
||||||
|
# pytest tests/unit/providers/vector_io/remote/test_milvus.py \
|
||||||
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||||
|
|
||||||
|
MILVUS_PROVIDER = "milvus"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_milvus_client() -> MagicMock:
|
||||||
|
"""Create a mock Milvus client with common method behaviors."""
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
# Mock collection operations
|
||||||
|
client.has_collection.return_value = False # Initially no collection
|
||||||
|
client.create_collection.return_value = None
|
||||||
|
client.drop_collection.return_value = None
|
||||||
|
|
||||||
|
# Mock insert operation
|
||||||
|
client.insert.return_value = {"insert_count": 10}
|
||||||
|
|
||||||
|
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||||
|
client.search.return_value = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"distance": 0.1,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"distance": 0.2,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock query operation used by the fallback keyword search (data should be dict, not JSON string)
|
||||||
|
client.query.return_value = [
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk1",
|
||||||
|
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||||
|
"score": 0.9,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk2",
|
||||||
|
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||||
|
"score": 0.8,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk3",
|
||||||
|
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||||
|
"score": 0.7,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def milvus_index(mock_milvus_client):
|
||||||
|
"""Create a MilvusIndex with mocked client."""
|
||||||
|
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||||
|
yield index
|
||||||
|
# No real cleanup needed since we're using mocks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
# Setup: collection doesn't exist initially, then exists after creation
|
||||||
|
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||||
|
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Verify collection was created and data was inserted
|
||||||
|
mock_milvus_client.create_collection.assert_called_once()
|
||||||
|
mock_milvus_client.insert.assert_called_once()
|
||||||
|
|
||||||
|
# Verify the insert call had the right number of chunks
|
||||||
|
insert_call = mock_milvus_client.insert.call_args
|
||||||
|
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_vector(
|
||||||
|
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||||
|
):
|
||||||
|
# Setup: Add chunks first
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Test vector search
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
mock_milvus_client.search.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Test keyword search
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Force BM25 search to fail
|
||||||
|
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||||
|
|
||||||
|
# Mock simple text search results
|
||||||
|
mock_milvus_client.query.return_value = [
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk1",
|
||||||
|
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk2",
|
||||||
|
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test keyword search that should fall back to simple text search
|
||||||
|
query_string = "Python"
|
||||||
|
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||||
|
|
||||||
|
# Verify response structure
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) > 0, "Fallback search should return results"
|
||||||
|
|
||||||
|
# Verify that the fallback ran: search was attempted once and failed, then query was called once
|
||||||
|
mock_milvus_client.query.assert_called_once()
|
||||||
|
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||||
|
|
||||||
|
# Verify the query uses parameterized filter with filter_params
|
||||||
|
query_call_args = mock_milvus_client.query.call_args
|
||||||
|
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
||||||
|
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
|
||||||
|
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
|
||||||
|
|
||||||
|
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
||||||
|
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||||
|
# Test collection deletion
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
|
||||||
|
await milvus_index.delete()
|
||||||
|
|
||||||
|
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
Loading…
Add table
Add a link
Reference in a new issue