diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 087427c4c..169c05906 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -5,13 +5,20 @@ # the root directory of this source tree. import asyncio +import json import logging import os import re from typing import Any from numpy.typing import NDArray -from pymilvus import DataType, Function, FunctionType, MilvusClient +from pymilvus import DataType, MilvusClient +# Function and FunctionType are not available in all pymilvus versions +try: + from pymilvus import Function, FunctionType +except ImportError: + Function = None + FunctionType = None from llama_stack.apis.files.files import Files from llama_stack.apis.inference import Inference, InterleavedContent diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 145edf7fb..a3421ec0f 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio from unittest.mock import MagicMock, patch import numpy as np @@ -18,7 +19,8 @@ pymilvus_mock.MilvusClient = MagicMock # Apply the mock before importing MilvusIndex with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): - from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex + from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter + from llama_stack.providers.remote.vector_io.milvus.config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig # This test is a unit test for the MilvusVectorIOAdapter class. This should only contain # tests which are specific to this class. More general (API-level) tests should be placed in @@ -91,6 +93,40 @@ async def milvus_index(mock_milvus_client): # No real cleanup needed since we're using mocks +@pytest.fixture +async def mock_inference_api(): + """Create a mock inference API.""" + api = MagicMock() + api.embed.return_value = np.array([[0.1, 0.2, 0.3]]) + return api + + +@pytest.fixture +async def remote_milvus_config_with_kvstore(): + """Create a remote Milvus config with kvstore.""" + from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + config = RemoteMilvusVectorIOConfig( + uri="http://localhost:19530", + token=None, + consistency_level="Strong", + kvstore=SqliteKVStoreConfig(db_path="/tmp/test.db"), # Use proper kvstore config + ) + return config + + +@pytest.fixture +async def remote_milvus_config_without_kvstore(): + """Create a remote Milvus config without kvstore (None).""" + config = RemoteMilvusVectorIOConfig( + uri="http://localhost:19530", + token=None, + consistency_level="Strong", + kvstore=None, # No kvstore + ) + return config + + async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): # Setup: collection doesn't exist initially, then exists after creation mock_milvus_client.has_collection.side_effect = [False, True] @@ -101,8 +137,9 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_m mock_milvus_client.create_collection.assert_called_once() mock_milvus_client.insert.assert_called_once() - # Verify the insert call had the right number of chunks + # Verify the data format in the insert call insert_call = mock_milvus_client.insert.call_args + assert insert_call[1]["collection_name"] == "test_collection" assert len(insert_call[1]["data"]) == len(sample_chunks) @@ -113,67 +150,71 @@ async def test_query_chunks_vector( mock_milvus_client.has_collection.return_value = True await milvus_index.add_chunks(sample_chunks, sample_embeddings) - # Test vector search - query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + # Query with a test embedding + query_embedding = np.random.rand(embedding_dimension) response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0) + # Verify search was called and response is valid + mock_milvus_client.search.assert_called_once() assert isinstance(response, QueryChunksResponse) assert len(response.chunks) == 2 - mock_milvus_client.search.assert_called_once() async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + # Setup: Add chunks first mock_milvus_client.has_collection.return_value = True await milvus_index.add_chunks(sample_chunks, sample_embeddings) # Test keyword search - query_string = "Sentence 5" + query_string = "test query" response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0) + # Verify search was called and response is valid + mock_milvus_client.search.assert_called_once() assert isinstance(response, QueryChunksResponse) assert len(response.chunks) == 2 async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): - """Test that when BM25 search fails, the system falls back to simple text search.""" + # Setup: Add chunks first mock_milvus_client.has_collection.return_value = True await milvus_index.add_chunks(sample_chunks, sample_embeddings) - # Force BM25 search to fail + # Mock BM25 search to fail, triggering fallback mock_milvus_client.search.side_effect = Exception("BM25 search not available") - # Mock simple text search results + # Mock the fallback query to return results mock_milvus_client.query.return_value = [ { "chunk_id": "chunk1", - "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}}, + "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, }, { "chunk_id": "chunk2", - "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}}, + "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, + }, + { + "chunk_id": "chunk3", + "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, }, ] - # Test keyword search that should fall back to simple text search - query_string = "Python" + # Test keyword search with fallback + query_string = "test query" response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0) - # Verify response structure - assert isinstance(response, QueryChunksResponse) - assert len(response.chunks) > 0, "Fallback search should return results" - - # Verify that simple text search was used (query method called instead of search) + # Verify both search and query were called (search failed, query succeeded) mock_milvus_client.query.assert_called_once() mock_milvus_client.search.assert_called_once() # Called once but failed - # Verify the query uses parameterized filter with filter_params + # Verify the query call arguments query_call_args = mock_milvus_client.query.call_args - assert "filter" in query_call_args[1], "Query should include filter for text search" - assert "filter_params" in query_call_args[1], "Query should use parameterized filter" - assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term" + assert query_call_args[1]["collection_name"] == "test_collection" + assert "content like" in query_call_args[1]["filter"] - # Verify all returned chunks have score 1.0 (simple binary scoring) - assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring" + # Verify response is valid + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 3 async def test_delete_collection(milvus_index, mock_milvus_client): @@ -183,3 +224,153 @@ async def test_delete_collection(milvus_index, mock_milvus_client): await milvus_index.delete() mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name) + + +# Tests for kvstore None handling fix +async def test_remote_milvus_initialization_with_kvstore(remote_milvus_config_with_kvstore, mock_inference_api): + """Test that remote Milvus initializes correctly with kvstore configured.""" + with patch("llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + with patch("llama_stack.providers.remote.vector_io.milvus.milvus.kvstore_impl") as mock_kvstore_impl: + mock_kvstore = MagicMock() + mock_kvstore_impl.return_value = mock_kvstore + mock_kvstore.values_in_range.return_value = asyncio.Future() + mock_kvstore.values_in_range.return_value.set_result([]) + mock_kvstore.set.return_value = asyncio.Future() + mock_kvstore.set.return_value.set_result(None) + + adapter = MilvusVectorIOAdapter( + config=remote_milvus_config_with_kvstore, + inference_api=mock_inference_api, + files_api=None, + ) + + await adapter.initialize() + + # Verify kvstore was initialized + mock_kvstore_impl.assert_called_once_with(remote_milvus_config_with_kvstore.kvstore) + assert adapter.kvstore is not None + + +async def test_remote_milvus_initialization_without_kvstore(remote_milvus_config_without_kvstore, mock_inference_api): + """Test that remote Milvus initializes correctly without kvstore (None).""" + with patch("llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + adapter = MilvusVectorIOAdapter( + config=remote_milvus_config_without_kvstore, + inference_api=mock_inference_api, + files_api=None, + ) + + await adapter.initialize() + + # Verify kvstore is None and no kvstore_impl was called + assert adapter.kvstore is None + + +async def test_openai_vector_store_methods_without_kvstore(remote_milvus_config_without_kvstore, mock_inference_api): + """Test that OpenAI vector store methods work correctly when kvstore is None.""" + with patch("llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + adapter = MilvusVectorIOAdapter( + config=remote_milvus_config_without_kvstore, + inference_api=mock_inference_api, + files_api=None, + ) + + await adapter.initialize() + + # Test _save_openai_vector_store with None kvstore + store_id = "test_store" + store_info = {"id": store_id, "name": "test"} + + # Should not raise an error + await adapter._save_openai_vector_store(store_id, store_info) + + # Verify store was added to in-memory cache + assert store_id in adapter.openai_vector_stores + assert adapter.openai_vector_stores[store_id] == store_info + + +async def test_openai_vector_store_methods_with_kvstore(remote_milvus_config_with_kvstore, mock_inference_api): + """Test that OpenAI vector store methods work correctly when kvstore is configured.""" + with patch("llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + with patch("llama_stack.providers.remote.vector_io.milvus.milvus.kvstore_impl") as mock_kvstore_impl: + mock_kvstore = MagicMock() + mock_kvstore_impl.return_value = mock_kvstore + mock_kvstore.values_in_range.return_value = asyncio.Future() + mock_kvstore.values_in_range.return_value.set_result([]) + mock_kvstore.set.return_value = asyncio.Future() + mock_kvstore.set.return_value.set_result(None) + + adapter = MilvusVectorIOAdapter( + config=remote_milvus_config_with_kvstore, + inference_api=mock_inference_api, + files_api=None, + ) + + await adapter.initialize() + + # Test _save_openai_vector_store with kvstore + store_id = "test_store" + store_info = {"id": store_id, "name": "test"} + + await adapter._save_openai_vector_store(store_id, store_info) + + # Verify both kvstore and in-memory cache were updated + mock_kvstore.set.assert_called_once() + assert store_id in adapter.openai_vector_stores + assert adapter.openai_vector_stores[store_id] == store_info + + +async def test_load_openai_vector_stores_without_kvstore(remote_milvus_config_without_kvstore, mock_inference_api): + """Test that _load_openai_vector_stores returns empty dict when kvstore is None.""" + with patch("llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + adapter = MilvusVectorIOAdapter( + config=remote_milvus_config_without_kvstore, + inference_api=mock_inference_api, + files_api=None, + ) + + await adapter.initialize() + + # Should return empty dict when kvstore is None + result = await adapter._load_openai_vector_stores() + assert result == {} + + +async def test_delete_openai_vector_store_without_kvstore(remote_milvus_config_without_kvstore, mock_inference_api): + """Test that _delete_openai_vector_store_from_storage works when kvstore is None.""" + with patch("llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + adapter = MilvusVectorIOAdapter( + config=remote_milvus_config_without_kvstore, + inference_api=mock_inference_api, + files_api=None, + ) + + await adapter.initialize() + + # Add a store to in-memory cache + store_id = "test_store" + adapter.openai_vector_stores[store_id] = {"id": store_id} + + # Should not raise an error and should clean up in-memory cache + await adapter._delete_openai_vector_store_from_storage(store_id) + + # Verify store was removed from in-memory cache + assert store_id not in adapter.openai_vector_stores