This commit is contained in:
Mustafa Elbehery 2025-10-03 14:11:23 +02:00 committed by GitHub
commit cfe5ac498f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 151 additions and 86 deletions

View file

@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from chromadb import PersistentClient
from pymilvus import MilvusClient, connections
from pymilvus import AsyncMilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
@ -141,7 +141,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
await index.initialize()
index.db_path = db_path
yield index
index.delete()
await index.delete()
@pytest.fixture
@ -178,13 +178,15 @@ def milvus_vec_db_path(tmp_path_factory):
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path)
client = AsyncMilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
await client.close()
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@ -14,7 +14,7 @@ from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
pymilvus_mock.AsyncMilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock
@ -40,48 +40,55 @@ async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
client.list_collections = AsyncMock(return_value=[]) # Initially no collections
client.create_collection = AsyncMock(return_value=None)
client.drop_collection = AsyncMock(return_value=None)
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
client.insert = AsyncMock(return_value={"insert_count": 10})
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
client.search = AsyncMock(
return_value=[
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
)
# Mock async query operation for keyword search (data should be dict, not JSON string)
client.query = AsyncMock(
return_value=[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
]
)
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
client.hybrid_search = AsyncMock(return_value=[])
client.delete = AsyncMock(return_value=None)
client.close = AsyncMock(return_value=None)
return client
@ -96,7 +103,7 @@ async def milvus_index(mock_milvus_client):
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
@ -113,7 +120,7 @@ async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
@ -126,7 +133,7 @@ async def test_query_chunks_vector(
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
@ -139,7 +146,7 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
@ -181,7 +188,7 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.delete()
@ -192,7 +199,7 @@ async def test_query_hybrid_search_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with RRF reranker."""
mock_milvus_client.has_collection.return_value = True
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
@ -244,7 +251,7 @@ async def test_query_hybrid_search_weighted(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with weighted reranker."""
mock_milvus_client.has_collection.return_value = True
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
@ -290,7 +297,7 @@ async def test_query_hybrid_search_default_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
mock_milvus_client.has_collection.return_value = True
mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results

View file

@ -30,12 +30,12 @@ async def test_initialize_index(vector_index):
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete()
vector_index.initialize()
await vector_index.delete()
await vector_index.initialize()
await vector_index.add_chunks(sample_chunks, sample_embeddings)
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content
vector_index.delete()
await vector_index.delete()
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):