mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-18 02:42:31 +00:00
feat: Implement keyword search in milvus (#2231)
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Integration Tests / discover-tests (push) Successful in 8s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 10s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 8s
Test Llama Stack Build / generate-matrix (push) Successful in 8s
Python Package Build Test / build (3.13) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 13s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Test External Providers / test-external-providers (venv) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 14s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 51s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 55s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 57s
Update ReadTheDocs / update-readthedocs (push) Failing after 50s
Pre-commit / pre-commit (push) Successful in 2m9s
Some checks failed
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 7s
Integration Tests / discover-tests (push) Successful in 8s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 10s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 8s
Test Llama Stack Build / generate-matrix (push) Successful in 8s
Python Package Build Test / build (3.13) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 13s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 15s
Test External Providers / test-external-providers (venv) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 14s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 19s
Integration Tests / test-matrix (push) Failing after 8s
Test Llama Stack Build / build (push) Failing after 5s
Python Package Build Test / build (3.12) (push) Failing after 51s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 55s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 57s
Update ReadTheDocs / update-readthedocs (push) Failing after 50s
Pre-commit / pre-commit (push) Successful in 2m9s
# What does this PR do? This PR adds the keyword search implementation for Milvus. Along with the implementation for remote Milvus, the tests require us to start a Milvus containers locally. In order to verify the implementation, run: ``` pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=short --disable-warnings --asyncio-mode=auto ``` You can also test the changes using the below script: ``` #!/usr/bin/env python3 import asyncio import os import uuid from typing import List from llama_stack_client import ( Agent, AgentEventLogger, LlamaStackClient, RAGDocument ) class MilvusRAGDemo: def __init__(self, base_url: str = "http://localhost:8321/"): self.client = LlamaStackClient(base_url=base_url) self.vector_db_id = f"milvus_rag_demo_{uuid.uuid4().hex[:8]}" self.model_id = None self.embedding_model_id = None self.embedding_dimension = None def setup_models(self): """Get available models and select appropriate ones for LLM and embeddings.""" models = self.client.models.list() # Select embedding model embedding_models = [m for m in models if m.model_type == "embedding"] if not embedding_models: raise ValueError("No embedding models found") self.embedding_model_id = embedding_models[0].identifier self.embedding_dimension = embedding_models[0].metadata["embedding_dimension"] def register_vector_db(self): print(f"Registering Milvus vector database: {self.vector_db_id}") response = self.client.vector_dbs.register( vector_db_id=self.vector_db_id, embedding_model=self.embedding_model_id, embedding_dimension=self.embedding_dimension, provider_id="milvus-remote", # Use remote Milvus ) print(f"Vector database registered successfully") return response def insert_documents(self): """Insert sample documents into the vector database.""" print("\nInserting sample documents...") # Sample documents about different topics documents = [ RAGDocument( document_id="ai_ml_basics", content=""" Artificial Intelligence (AI) and Machine Learning (ML) are transforming the world. AI refers to the simulation of human intelligence in machines, while ML is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed. Deep learning, a subset of ML, uses neural networks with multiple layers to process complex patterns in data. Key concepts in AI/ML include: - Supervised Learning: Training with labeled data - Unsupervised Learning: Finding patterns in unlabeled data - Reinforcement Learning: Learning through trial and error - Neural Networks: Computing systems inspired by biological brains """, mime_type="text/plain", metadata={"topic": "technology", "category": "ai_ml"}, ), ] # Insert documents with chunking self.client.tool_runtime.rag_tool.insert( documents=documents, vector_db_id=self.vector_db_id, chunk_size_in_tokens=200, # Smaller chunks for better granularity ) print(f"Inserted {len(documents)} documents with chunking") def test_keyword_search(self): """Test keyword-based search using BM25.""" queries = [ "neural networks", "Python frameworks", "data cleaning", ] for query in queries: response = self.client.vector_io.query( vector_db_id=self.vector_db_id, query=query, params={ "mode": "keyword", # Keyword search "max_chunks": 3, "score_threshold": 0.0, } ) for i, (chunk, score) in enumerate(zip(response.chunks, response.scores)): print(f" {i+1}. Score: {score:.4f}") print(f" Content: {chunk.content[:100]}...") print(f" Metadata: {chunk.metadata}") def run_demo(self): try: self.setup_models() self.register_vector_db() self.insert_documents() self.test_keyword_search() except Exception as e: print(f"Error during demo: {e}") raise def main(): """Main function to run the demo.""" # Check if Llama Stack server is running demo = MilvusRAGDemo() try: demo.run_demo() except Exception as e: print(f"Demo failed: {e}") if __name__ == "__main__": main() ``` [//]: # (## Documentation) --------- Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
33f0d83ad3
commit
4ae5656c2f
4 changed files with 331 additions and 8 deletions
|
@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
|
||||
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
|
||||
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
|
||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
|
||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
|
||||
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
|
||||
|
||||
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
|
||||
|
@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
|||
```yaml
|
||||
uri: ${env.MILVUS_ENDPOINT}
|
||||
token: ${env.MILVUS_TOKEN}
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
|
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
uri: str = Field(description="The URI of the Milvus server")
|
||||
token: str | None = Field(description="The token of the Milvus server")
|
||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
|
||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
||||
|
||||
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
||||
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
||||
|
@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
|
||||
return {
|
||||
"uri": "${env.MILVUS_ENDPOINT}",
|
||||
"token": "${env.MILVUS_TOKEN}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_remote_registry.db",
|
||||
),
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import re
|
|||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import DataType, MilvusClient
|
||||
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
||||
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
|
@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
|
|||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
field_name="chunk_id",
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=100,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="content",
|
||||
datatype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
enable_analyzer=True, # Enable text analysis for BM25
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=len(embeddings[0]),
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="chunk_content",
|
||||
datatype=DataType.JSON,
|
||||
)
|
||||
# Add sparse vector field for BM25 (required by the function)
|
||||
schema.add_field(
|
||||
field_name="sparse",
|
||||
datatype=DataType.SPARSE_FLOAT_VECTOR,
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
index_type="FLAT",
|
||||
metric_type="COSINE",
|
||||
)
|
||||
# Add index for sparse field (required by BM25 function)
|
||||
index_params.add_index(
|
||||
field_name="sparse",
|
||||
index_type="SPARSE_INVERTED_INDEX",
|
||||
metric_type="BM25",
|
||||
)
|
||||
|
||||
# Add BM25 function for full-text search
|
||||
bm25_function = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=["content"],
|
||||
output_field_names=["sparse"],
|
||||
function_type=FunctionType.BM25,
|
||||
)
|
||||
schema.add_function(bm25_function)
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
self.collection_name,
|
||||
dimension=len(embeddings[0]),
|
||||
auto_id=True,
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
consistency_level=self.consistency_level,
|
||||
)
|
||||
|
||||
|
@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
|
|||
data.append(
|
||||
{
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"content": chunk.content,
|
||||
"vector": embedding,
|
||||
"chunk_content": chunk.model_dump(),
|
||||
# sparse field will be handled by BM25 function automatically
|
||||
}
|
||||
)
|
||||
try:
|
||||
|
@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
data=[embedding],
|
||||
anns_field="vector",
|
||||
limit=k,
|
||||
output_fields=["*"],
|
||||
search_params={"params": {"radius": score_threshold}},
|
||||
|
@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
||||
"""
|
||||
Perform BM25-based keyword search using Milvus's built-in full-text search.
|
||||
"""
|
||||
try:
|
||||
# Use Milvus's built-in BM25 search
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
data=[query_string], # Raw text query
|
||||
anns_field="sparse", # Use sparse field for BM25
|
||||
output_fields=["chunk_content"], # Output the chunk content
|
||||
limit=k,
|
||||
search_params={
|
||||
"params": {
|
||||
"drop_ratio_search": 0.2, # Ignore low-importance terms
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for res in search_res[0]:
|
||||
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||
chunks.append(chunk)
|
||||
scores.append(res["distance"]) # BM25 score from Milvus
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||
|
||||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing BM25 search: {e}")
|
||||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
async def _fallback_keyword_search(
|
||||
self,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
"""
|
||||
Fallback to simple text search when BM25 search is not available.
|
||||
"""
|
||||
# Simple text search using content field
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.query,
|
||||
collection_name=self.collection_name,
|
||||
filter='content like "%{content}%"',
|
||||
filter_params={"content": query_string},
|
||||
output_fields=["*"],
|
||||
limit=k,
|
||||
)
|
||||
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
|
||||
scores = [1.0] * len(chunks) # Simple binary score for text search
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
@ -247,6 +361,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
if params and params.get("mode") == "keyword":
|
||||
# Check if this is inline Milvus (Milvus-Lite)
|
||||
if hasattr(self.config, "db_path"):
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported in Milvus-Lite. "
|
||||
"Please use a remote Milvus server for keyword search functionality."
|
||||
)
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _save_openai_vector_store_file(
|
||||
|
|
191
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
191
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
|
@ -0,0 +1,191 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
|
||||
# Mock the entire pymilvus module
|
||||
pymilvus_mock = MagicMock()
|
||||
pymilvus_mock.DataType = MagicMock()
|
||||
pymilvus_mock.MilvusClient = MagicMock
|
||||
|
||||
# Apply the mock before importing MilvusIndex
|
||||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||
|
||||
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_milvus.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
MILVUS_PROVIDER = "milvus"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mock_milvus_client() -> MagicMock:
|
||||
"""Create a mock Milvus client with common method behaviors."""
|
||||
client = MagicMock()
|
||||
|
||||
# Mock collection operations
|
||||
client.has_collection.return_value = False # Initially no collection
|
||||
client.create_collection.return_value = None
|
||||
client.drop_collection.return_value = None
|
||||
|
||||
# Mock insert operation
|
||||
client.insert.return_value = {"insert_count": 10}
|
||||
|
||||
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||
client.search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Mock query operation for keyword search (data should be dict, not JSON string)
|
||||
client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||
"score": 0.8,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk3",
|
||||
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||
"score": 0.7,
|
||||
},
|
||||
]
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def milvus_index(mock_milvus_client):
|
||||
"""Create a MilvusIndex with mocked client."""
|
||||
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||
yield index
|
||||
# No real cleanup needed since we're using mocks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
# Setup: collection doesn't exist initially, then exists after creation
|
||||
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Verify collection was created and data was inserted
|
||||
mock_milvus_client.create_collection.assert_called_once()
|
||||
mock_milvus_client.insert.assert_called_once()
|
||||
|
||||
# Verify the insert call had the right number of chunks
|
||||
insert_call = mock_milvus_client.insert.call_args
|
||||
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_vector(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
# Setup: Add chunks first
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test vector search
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
mock_milvus_client.search.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test keyword search
|
||||
query_string = "Sentence 5"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Force BM25 search to fail
|
||||
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||
|
||||
# Mock simple text search results
|
||||
mock_milvus_client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
||||
},
|
||||
]
|
||||
|
||||
# Test keyword search that should fall back to simple text search
|
||||
query_string = "Python"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0, "Fallback search should return results"
|
||||
|
||||
# Verify that simple text search was used (query method called instead of search)
|
||||
mock_milvus_client.query.assert_called_once()
|
||||
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||
|
||||
# Verify the query uses parameterized filter with filter_params
|
||||
query_call_args = mock_milvus_client.query.call_args
|
||||
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
||||
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
|
||||
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
|
||||
|
||||
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
||||
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||
# Test collection deletion
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
|
||||
await milvus_index.delete()
|
||||
|
||||
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
Loading…
Add table
Add a link
Reference in a new issue