This commit is contained in:
Mustafa Elbehery 2025-12-02 09:58:57 +01:00 committed by GitHub
commit 154b7f568f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 547 additions and 40 deletions

View file

@ -9,19 +9,23 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from pymilvus import AsyncMilvusClient, connections
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore import register_kvstore_backends
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore
EMBEDDING_DIMENSION = 768
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
@ -140,7 +144,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
await index.initialize()
index.db_path = db_path
yield index
index.delete()
await index.delete()
@pytest.fixture
@ -169,6 +173,48 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
await adapter.shutdown()
@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
return db_path
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = AsyncMilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
await client.close()
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = MilvusVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_store(
VectorStore(
identifier=adapter.metadata_collection_name,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")

View file

@ -0,0 +1,383 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.AsyncMilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest.fixture
async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
client.list_collections = AsyncMock(return_value=[]) # Initially no collections
client.create_collection = AsyncMock(return_value=None)
client.drop_collection = AsyncMock(return_value=None)
client.insert = AsyncMock(return_value={"insert_count": 10})
client.search = AsyncMock(
return_value=[
[
{
"id": 0,
"distance": 0.1,
"entity": {
"chunk_content": {
"chunk_id": "chunk1",
"content": "mock chunk 1",
"metadata": {"document_id": "doc1"},
}
},
},
{
"id": 1,
"distance": 0.2,
"entity": {
"chunk_content": {
"chunk_id": "chunk2",
"content": "mock chunk 2",
"metadata": {"document_id": "doc2"},
}
},
},
]
]
)
# Mock async query operation for keyword search (data should be dict, not JSON string)
client.query = AsyncMock(
return_value=[
{
"chunk_id": "chunk1",
"chunk_content": {"chunk_id": "chunk1", "content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"chunk_id": "chunk2", "content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"chunk_id": "chunk3", "content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
)
client.hybrid_search = AsyncMock(return_value=[])
client.delete = AsyncMock(return_value=None)
client.close = AsyncMock(return_value=None)
return client
@pytest.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
yield index
# No real cleanup needed since we're using mocks
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Verify collection was created and data was inserted
mock_milvus_client.create_collection.assert_called_once()
mock_milvus_client.insert.assert_called_once()
# Verify the insert call had the right number of chunks
insert_call = mock_milvus_client.insert.call_args
assert len(insert_call[1]["data"]) == len(sample_chunks)
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_milvus_client.search.assert_called_once()
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
# Mock simple text search results
mock_milvus_client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {
"chunk_id": "chunk1",
"content": "Python programming language",
"metadata": {"document_id": "doc1"},
},
},
{
"chunk_id": "chunk2",
"chunk_content": {
"chunk_id": "chunk2",
"content": "Machine learning algorithms",
"metadata": {"document_id": "doc2"},
},
},
]
# Test keyword search that should fall back to simple text search
query_string = "Python"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
# Verify response structure
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0, "Fallback search should return results"
# Verify that simple text search was used (query method called instead of search)
mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query uses parameterized filter with filter_params
query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
async def test_query_hybrid_search_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with RRF reranker."""
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {
"chunk_content": {
"chunk_id": "chunk1",
"content": "mock chunk 1",
"metadata": {"document_id": "doc1"},
}
},
},
{
"id": 1,
"distance": 0.2,
"entity": {
"chunk_content": {
"chunk_id": "chunk2",
"content": "mock chunk 2",
"metadata": {"document_id": "doc2"},
}
},
},
]
]
# Test hybrid search with RRF reranker
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify hybrid search was called with correct parameters
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
# Check that the request contains both vector and BM25 search requests
reqs = call_args[1]["reqs"]
assert len(reqs) == 2
assert reqs[0].anns_field == "vector"
assert reqs[1].anns_field == "sparse"
ranker = call_args[1]["ranker"]
assert ranker is not None
async def test_query_hybrid_search_weighted(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with weighted reranker."""
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {
"chunk_content": {
"chunk_id": "chunk1",
"content": "mock chunk 1",
"metadata": {"document_id": "doc1"},
}
},
},
{
"id": 1,
"distance": 0.2,
"entity": {
"chunk_content": {
"chunk_id": "chunk2",
"content": "mock chunk 2",
"metadata": {"document_id": "doc2"},
}
},
},
]
]
# Test hybrid search with weighted reranker
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 0.7},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify hybrid search was called with correct parameters
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"]
assert ranker is not None
async def test_query_hybrid_search_default_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {
"chunk_content": {
"chunk_id": "chunk1",
"content": "mock chunk 1",
"metadata": {"document_id": "doc1"},
}
},
},
]
]
# Test hybrid search with default reranker (should be RRF)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="unknown_type", # Should default to RRF
reranker_params=None, # Should use default impact_factor
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 1
# Verify hybrid search was called with RRF reranker
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"]
assert ranker is not None

View file

@ -48,12 +48,12 @@ async def test_initialize_index(vector_index):
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete()
vector_index.initialize()
await vector_index.delete()
await vector_index.initialize()
await vector_index.add_chunks(sample_chunks, sample_embeddings)
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content
vector_index.delete()
await vector_index.delete()
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):