This commit is contained in:
Mustafa Elbehery 2025-12-02 09:58:57 +01:00 committed by GitHub
commit 154b7f568f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 547 additions and 40 deletions

View file

@ -125,6 +125,8 @@ unit = [
"together",
"coverage",
"moto[s3]>=5.1.10",
"pymilvus>=2.6.1",
"milvus-lite>=2.5.0",
]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment

View file

@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from typing import Any
from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
from llama_stack.core.storage.kvstore import kvstore_impl
from llama_stack.log import get_logger
@ -49,12 +48,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class MilvusIndex(EmbeddingIndex):
def __init__(
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
self,
client: AsyncMilvusClient,
collection_name: str,
consistency_level="Strong",
kvstore: KVStore | None = None,
parent_adapter=None,
):
self.client = client
self.collection_name = sanitize_collection_name(collection_name)
self.consistency_level = consistency_level
self.kvstore = kvstore
self._parent_adapter = parent_adapter
async def initialize(self):
# MilvusIndex does not require explicit initialization
@ -62,15 +67,39 @@ class MilvusIndex(EmbeddingIndex):
pass
async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
try:
collections = await self.client.list_collections()
if self.collection_name in collections:
await self.client.drop_collection(collection_name=self.collection_name)
except Exception as e:
logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
try:
collections = await self.client.list_collections()
collection_exists = self.collection_name in collections
except Exception as e:
logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
# If it's an event loop issue, try to recreate the client
if "attached to a different loop" in str(e):
logger.warning("Recreating client due to event loop issue")
if hasattr(self, "_parent_adapter"):
await self._parent_adapter._recreate_client()
collections = await self.client.list_collections()
collection_exists = self.collection_name in collections
else:
# Assume collection doesn't exist if we can't check
collection_exists = False
else:
# Assume collection doesn't exist if we can't check due to other issues
collection_exists = False
if not collection_exists:
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
@ -101,13 +130,16 @@ class MilvusIndex(EmbeddingIndex):
)
schema.add_function(bm25_function)
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
try:
await self.client.create_collection(
self.collection_name,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
except Exception as e:
logger.error(f"Failed to create collection {self.collection_name}: {e}")
raise e
data = []
for chunk, embedding in zip(chunks, embeddings, strict=False):
@ -121,14 +153,16 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
await self.client.insert(
self.collection_name,
data=data,
)
except Exception as e:
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
@ -146,8 +180,7 @@ class MilvusIndex(EmbeddingIndex):
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
@ -183,8 +216,7 @@ class MilvusIndex(EmbeddingIndex):
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
search_res = await self.client.query(
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
@ -231,8 +263,7 @@ class MilvusIndex(EmbeddingIndex):
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor)
search_res = await asyncio.to_thread(
self.client.hybrid_search,
search_res = await self.client.hybrid_search(
collection_name=self.collection_name,
reqs=search_requests,
ranker=rerank,
@ -258,9 +289,7 @@ class MilvusIndex(EmbeddingIndex):
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
except Exception as e:
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
@ -283,6 +312,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.persistence)
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = AsyncMilvusClient(uri=uri)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
@ -296,26 +334,41 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
collection_name=vector_store.identifier,
consistency_level=self.config.consistency_level,
kvstore=self.kvstore,
parent_adapter=self,
),
inference_api=self.inference_api,
)
self.cache[vector_store.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
self.client.close()
if self.client:
await self.client.close()
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def _recreate_client(self) -> None:
    """Replace ``self.client`` with a fresh AsyncMilvusClient.

    The current client is closed on a best-effort basis (a failure is only
    logged), a new client is built from ``self.config``, and every cached
    index wrapper that exposes ``index.client`` is re-pointed at it.
    """
    try:
        if self.client:
            await self.client.close()
    except Exception as e:
        # A failed close must not prevent the replacement client from being built.
        logger.warning(f"Error closing old client: {e}")

    if not isinstance(self.config, RemoteMilvusVectorIOConfig):
        logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
        self.client = AsyncMilvusClient(uri=os.path.expanduser(self.config.db_path))
    else:
        logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
        self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))

    # Cached indexes hold their own reference to the client, so swap it there too.
    for cached in self.cache.values():
        if not hasattr(cached, "index"):
            continue
        if hasattr(cached.index, "client"):
            cached.index.client = self.client
async def register_vector_store(self, vector_store: VectorStore) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
@ -323,7 +376,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
consistency_level = "Strong"
index = VectorStoreWithIndex(
vector_store=vector_store,
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
index=MilvusIndex(
client=self.client,
collection_name=vector_store.identifier,
consistency_level=consistency_level,
kvstore=self.kvstore,
parent_adapter=self,
),
inference_api=self.inference_api,
)
@ -345,7 +404,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store=vector_store,
index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
index=MilvusIndex(
client=self.client,
collection_name=vector_store.identifier,
kvstore=self.kvstore,
parent_adapter=self,
),
inference_api=self.inference_api,
)
self.cache[vector_store_id] = index

8
tests/conftest.py Normal file
View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# This file intentionally left empty - pytest will auto-discover conftest.py files
# in subdirectories and load their fixtures automatically.

View file

@ -9,19 +9,23 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from pymilvus import AsyncMilvusClient, connections
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore import register_kvstore_backends
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore
EMBEDDING_DIMENSION = 768
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
@ -140,7 +144,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
await index.initialize()
index.db_path = db_path
yield index
index.delete()
await index.delete()
@pytest.fixture
@ -169,6 +173,48 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
await adapter.shutdown()
@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
    """Return the path (as a string) of ``test_milvus.db`` under pytest's base temp directory."""
    return str(tmp_path_factory.getbasetemp() / "test_milvus.db")
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
    """Yield a MilvusIndex bound to a randomly named collection in the Milvus db file.

    ``embedding_dimension`` is requested but not read by this fixture.

    On teardown the async client is closed and the ``MILVUS_ALIAS`` connection
    opened here is disconnected, so the connection is not left registered
    after the test finishes.
    """
    client = AsyncMilvusClient(milvus_vec_db_path)
    # Random suffix keeps collections created by different tests apart in the same db file.
    name = f"{COLLECTION_PREFIX}_{np.random.randint(1_000_000)}"
    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
    try:
        index = MilvusIndex(client, name, consistency_level="Strong")
        index.db_path = milvus_vec_db_path
        yield index
    finally:
        try:
            await client.close()
        finally:
            # Release the connection even if closing the client failed.
            connections.disconnect(MILVUS_ALIAS)
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
    """Yield an initialized MilvusVectorIOAdapter with one registered vector store.

    The adapter is built against the Milvus db file at ``milvus_vec_db_path``
    and is shut down once the test is done.
    """
    adapter = MilvusVectorIOAdapter(
        config=MilvusVectorIOConfig(
            db_path=milvus_vec_db_path,
            kvstore=SqliteKVStoreConfig(),
        ),
        inference_api=mock_inference_api,
        files_api=None,
    )
    await adapter.initialize()

    # Register the store the tests operate on, named after the adapter's metadata collection.
    store = VectorStore(
        identifier=adapter.metadata_collection_name,
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=128,
    )
    await adapter.register_vector_store(store)

    yield adapter
    await adapter.shutdown()
@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")

View file

@ -0,0 +1,383 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_io import QueryChunksResponse
# Stand-in for the whole pymilvus module.
pymilvus_mock = MagicMock()
# DataType is a mock instance; the remaining names are bound to the MagicMock
# class itself, so "constructing" them yields a fresh mock.
pymilvus_mock.configure_mock(
    DataType=MagicMock(),
    AsyncMilvusClient=MagicMock,
    RRFRanker=MagicMock,
    WeightedRanker=MagicMock,
    AnnSearchRequest=MagicMock,
)

# The mock must be in sys.modules before the provider module is imported,
# because that module imports names from pymilvus at import time.
with patch.dict("sys.modules", pymilvus=pymilvus_mock):
    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusIndex class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest.fixture
async def mock_milvus_client() -> MagicMock:
    """Build a MagicMock Milvus client whose operations are AsyncMocks with canned results."""
    # Two vector-search hits, wrapped in an outer list (one result set).
    search_hits = [
        {
            "id": position,
            "distance": distance,
            "entity": {
                "chunk_content": {
                    "chunk_id": f"chunk{position + 1}",
                    "content": f"mock chunk {position + 1}",
                    "metadata": {"document_id": f"doc{position + 1}"},
                }
            },
        }
        for position, distance in enumerate((0.1, 0.2))
    ]

    # Rows returned by the query operation used for keyword search
    # (chunk_content is a dict, not a JSON string).
    query_rows = [
        {
            "chunk_id": f"chunk{number}",
            "chunk_content": {
                "chunk_id": f"chunk{number}",
                "content": f"mock chunk {number}",
                "metadata": {"document_id": f"doc{number}"},
            },
            "score": score,
        }
        for number, score in ((1, 0.9), (2, 0.8), (3, 0.7))
    ]

    client = MagicMock()
    client.list_collections = AsyncMock(return_value=[])  # no collections to begin with
    client.create_collection = AsyncMock(return_value=None)
    client.drop_collection = AsyncMock(return_value=None)
    client.insert = AsyncMock(return_value={"insert_count": 10})
    client.search = AsyncMock(return_value=[search_hits])
    client.query = AsyncMock(return_value=query_rows)
    client.hybrid_search = AsyncMock(return_value=[])
    client.delete = AsyncMock(return_value=None)
    client.close = AsyncMock(return_value=None)
    return client
@pytest.fixture
async def milvus_index(mock_milvus_client):
    """Yield a MilvusIndex for ``test_collection`` that talks to the mocked client."""
    # Nothing to tear down after the yield: the client is a mock.
    yield MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """Adding chunks when the collection is absent creates it and inserts every chunk."""
    # First listing reports no collections; any later listing would include the new one.
    mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]]

    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    mock_milvus_client.create_collection.assert_called_once()
    mock_milvus_client.insert.assert_called_once()

    # The rows handed to insert must match the chunks one-to-one in count.
    inserted_rows = mock_milvus_client.insert.call_args.kwargs["data"]
    assert len(inserted_rows) == len(sample_chunks)
async def test_query_chunks_vector(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Vector search issues exactly one client search and returns the two mocked chunks."""
    # Collection already exists, so add_chunks goes straight to inserting.
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    probe = np.random.rand(embedding_dimension).astype(np.float32)
    result = await milvus_index.query_vector(probe, k=2, score_threshold=0.0)

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 2
    mock_milvus_client.search.assert_called_once()
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """Keyword search returns a QueryChunksResponse holding two chunks."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    result = await milvus_index.query_keyword(query_string="Sentence 5", k=2, score_threshold=0.0)

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 2
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """When the BM25 search call raises, keyword search falls back to a filtered query."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Make the BM25 path fail so the fallback is exercised.
    mock_milvus_client.search.side_effect = Exception("BM25 search not available")

    # Rows the fallback text query will return.
    fallback_rows = [
        {
            "chunk_id": chunk_id,
            "chunk_content": {
                "chunk_id": chunk_id,
                "content": content,
                "metadata": {"document_id": document_id},
            },
        }
        for chunk_id, content, document_id in (
            ("chunk1", "Python programming language", "doc1"),
            ("chunk2", "Machine learning algorithms", "doc2"),
        )
    ]
    mock_milvus_client.query.return_value = fallback_rows

    result = await milvus_index.query_keyword(query_string="Python", k=3, score_threshold=0.0)

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) > 0, "Fallback search should return results"

    # The fallback used query(); search() was attempted exactly once and failed.
    mock_milvus_client.query.assert_called_once()
    mock_milvus_client.search.assert_called_once()

    # The fallback must pass the search term through filter_params rather than inline it.
    query_kwargs = mock_milvus_client.query.call_args.kwargs
    assert "filter" in query_kwargs, "Query should include filter for text search"
    assert "filter_params" in query_kwargs, "Query should use parameterized filter"
    assert query_kwargs["filter_params"]["content"] == "Python", "Filter params should contain the search term"

    # Every fallback match is scored 1.0 (match / no match only).
    assert all(score == 1.0 for score in result.scores), "Simple text search should use binary scoring"
async def test_delete_collection(milvus_index, mock_milvus_client):
    """Deleting the index drops its collection when the client lists it as existing."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]

    await milvus_index.delete()

    expected_name = milvus_index.collection_name
    mock_milvus_client.drop_collection.assert_called_once_with(collection_name=expected_name)
async def test_query_hybrid_search_rrf(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Hybrid search with the RRF reranker sends a vector and a sparse request plus a ranker."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Canned hybrid-search result: one result set containing two hits.
    hits = [
        {
            "id": position,
            "distance": distance,
            "entity": {
                "chunk_content": {
                    "chunk_id": f"chunk{position + 1}",
                    "content": f"mock chunk {position + 1}",
                    "metadata": {"document_id": f"doc{position + 1}"},
                }
            },
        }
        for position, distance in enumerate((0.1, 0.2))
    ]
    mock_milvus_client.hybrid_search.return_value = [hits]

    probe = np.random.rand(embedding_dimension).astype(np.float32)
    result = await milvus_index.query_hybrid(
        embedding=probe,
        query_string="test query",
        k=2,
        score_threshold=0.0,
        reranker_type="rrf",
        reranker_params={"impact_factor": 60.0},
    )

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 2
    assert len(result.scores) == 2

    mock_milvus_client.hybrid_search.assert_called_once()
    sent_kwargs = mock_milvus_client.hybrid_search.call_args.kwargs

    # Exactly two sub-requests: dense vector first, BM25 sparse second.
    requests = sent_kwargs["reqs"]
    assert len(requests) == 2
    assert requests[0].anns_field == "vector"
    assert requests[1].anns_field == "sparse"

    assert sent_kwargs["ranker"] is not None
async def test_query_hybrid_search_weighted(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Hybrid search with the weighted reranker returns both hits and passes a ranker."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Canned hybrid-search result: one result set containing two hits.
    hits = [
        {
            "id": position,
            "distance": distance,
            "entity": {
                "chunk_content": {
                    "chunk_id": f"chunk{position + 1}",
                    "content": f"mock chunk {position + 1}",
                    "metadata": {"document_id": f"doc{position + 1}"},
                }
            },
        }
        for position, distance in enumerate((0.1, 0.2))
    ]
    mock_milvus_client.hybrid_search.return_value = [hits]

    probe = np.random.rand(embedding_dimension).astype(np.float32)
    result = await milvus_index.query_hybrid(
        embedding=probe,
        query_string="test query",
        k=2,
        score_threshold=0.0,
        reranker_type="weighted",
        reranker_params={"alpha": 0.7},
    )

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 2
    assert len(result.scores) == 2

    mock_milvus_client.hybrid_search.assert_called_once()
    sent_kwargs = mock_milvus_client.hybrid_search.call_args.kwargs
    assert sent_kwargs["ranker"] is not None
async def test_query_hybrid_search_default_rrf(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Hybrid search with an unrecognized reranker type and no params still runs with a ranker."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Canned hybrid-search result: one result set containing a single hit.
    only_hit = {
        "id": 0,
        "distance": 0.1,
        "entity": {
            "chunk_content": {
                "chunk_id": "chunk1",
                "content": "mock chunk 1",
                "metadata": {"document_id": "doc1"},
            }
        },
    }
    mock_milvus_client.hybrid_search.return_value = [[only_hit]]

    probe = np.random.rand(embedding_dimension).astype(np.float32)
    result = await milvus_index.query_hybrid(
        embedding=probe,
        query_string="test query",
        k=1,
        score_threshold=0.0,
        reranker_type="unknown_type",  # not a known type; expected to fall back to RRF
        reranker_params=None,  # expected to fall back to the default impact_factor
    )

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 1

    mock_milvus_client.hybrid_search.assert_called_once()
    sent_kwargs = mock_milvus_client.hybrid_search.call_args.kwargs
    assert sent_kwargs["ranker"] is not None

View file

@ -48,12 +48,12 @@ async def test_initialize_index(vector_index):
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete()
vector_index.initialize()
await vector_index.delete()
await vector_index.initialize()
await vector_index.add_chunks(sample_chunks, sample_embeddings)
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content
vector_index.delete()
await vector_index.delete()
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):

6
uv.lock generated
View file

@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.12"
resolution-markers = [
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
@ -2133,9 +2133,11 @@ unit = [
{ name = "faiss-cpu" },
{ name = "litellm" },
{ name = "mcp" },
{ name = "milvus-lite" },
{ name = "moto", extra = ["s3"] },
{ name = "ollama" },
{ name = "psycopg2-binary" },
{ name = "pymilvus" },
{ name = "pypdf" },
{ name = "sqlalchemy", extra = ["asyncio"] },
{ name = "sqlite-vec" },
@ -2277,9 +2279,11 @@ unit = [
{ name = "faiss-cpu" },
{ name = "litellm" },
{ name = "mcp" },
{ name = "milvus-lite", specifier = ">=2.5.0" },
{ name = "moto", extras = ["s3"], specifier = ">=5.1.10" },
{ name = "ollama" },
{ name = "psycopg2-binary", specifier = ">=2.9.0" },
{ name = "pymilvus", specifier = ">=2.6.1" },
{ name = "pypdf", specifier = ">=6.1.3" },
{ name = "sqlalchemy" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },