commit cfe5ac498f
Mustafa Elbehery 2025-10-03 14:11:23 +02:00 committed by GitHub
4 changed files with 151 additions and 86 deletions


@@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import asyncio
import os
from typing import Any
from numpy.typing import NDArray
-from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
+from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
@@ -48,12 +47,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class MilvusIndex(EmbeddingIndex):
def __init__(
-self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
+self,
+client: AsyncMilvusClient,
+collection_name: str,
+consistency_level="Strong",
+kvstore: KVStore | None = None,
+parent_adapter=None,
):
self.client = client
self.collection_name = sanitize_collection_name(collection_name)
self.consistency_level = consistency_level
self.kvstore = kvstore
+self._parent_adapter = parent_adapter
async def initialize(self):
# MilvusIndex does not require explicit initialization
@@ -61,15 +66,39 @@ class MilvusIndex(EmbeddingIndex):
pass
async def delete(self):
-if await asyncio.to_thread(self.client.has_collection, self.collection_name):
-    await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
+try:
+    collections = await self.client.list_collections()
+    if self.collection_name in collections:
+        await self.client.drop_collection(collection_name=self.collection_name)
+except Exception as e:
+    logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
-if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+try:
+    collections = await self.client.list_collections()
+    collection_exists = self.collection_name in collections
+except Exception as e:
+    logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
+    # If it's an event loop issue, try to recreate the client
+    if "attached to a different loop" in str(e):
+        logger.warning("Recreating client due to event loop issue")
+        if self._parent_adapter is not None:
+            await self._parent_adapter._recreate_client()
+            collections = await self.client.list_collections()
+            collection_exists = self.collection_name in collections
+        else:
+            # Assume collection doesn't exist if we can't check
+            collection_exists = False
+    else:
+        # Assume collection doesn't exist if we can't check due to other issues
+        collection_exists = False
+if not collection_exists:
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
@@ -123,13 +152,16 @@ class MilvusIndex(EmbeddingIndex):
)
schema.add_function(bm25_function)
-await asyncio.to_thread(
-    self.client.create_collection,
-    self.collection_name,
-    schema=schema,
-    index_params=index_params,
-    consistency_level=self.consistency_level,
-)
+try:
+    await self.client.create_collection(
+        self.collection_name,
+        schema=schema,
+        index_params=index_params,
+        consistency_level=self.consistency_level,
+    )
+except Exception as e:
+    logger.error(f"Failed to create collection {self.collection_name}: {e}")
+    raise e
data = []
for chunk, embedding in zip(chunks, embeddings, strict=False):
@@ -143,8 +175,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
-await asyncio.to_thread(
-    self.client.insert,
+await self.client.insert(
self.collection_name,
data=data,
)
@@ -153,8 +184,7 @@ class MilvusIndex(EmbeddingIndex):
raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-search_res = await asyncio.to_thread(
-    self.client.search,
+search_res = await self.client.search(
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
@@ -177,8 +207,7 @@ class MilvusIndex(EmbeddingIndex):
"""
try:
# Use Milvus's built-in BM25 search
-search_res = await asyncio.to_thread(
-    self.client.search,
+search_res = await self.client.search(
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
@@ -219,8 +248,7 @@ class MilvusIndex(EmbeddingIndex):
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
-search_res = await asyncio.to_thread(
-    self.client.query,
+search_res = await self.client.query(
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
@@ -267,8 +295,7 @@ class MilvusIndex(EmbeddingIndex):
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor)
-search_res = await asyncio.to_thread(
-    self.client.hybrid_search,
+search_res = await self.client.hybrid_search(
collection_name=self.collection_name,
reqs=search_requests,
ranker=rerank,
@@ -294,9 +321,7 @@ class MilvusIndex(EmbeddingIndex):
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
-await asyncio.to_thread(
-    self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
-)
+await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
except Exception as e:
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
@@ -321,6 +346,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
+if isinstance(self.config, RemoteMilvusVectorIOConfig):
+    logger.info(f"Connecting to Milvus server at {self.config.uri}")
+    self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+else:
+    logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
+    uri = os.path.expanduser(self.config.db_path)
+    self.client = AsyncMilvusClient(uri=uri)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
@@ -334,23 +368,38 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
collection_name=vector_db.identifier,
consistency_level=self.config.consistency_level,
kvstore=self.kvstore,
+parent_adapter=self,
),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
-if isinstance(self.config, RemoteMilvusVectorIOConfig):
-    logger.info(f"Connecting to Milvus server at {self.config.uri}")
-    self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
-else:
-    logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
-    uri = os.path.expanduser(self.config.db_path)
-    self.client = MilvusClient(uri=uri)
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
-self.client.close()
+if self.client:
+    await self.client.close()
+async def _recreate_client(self) -> None:
+    """Recreate the AsyncMilvusClient when event loop issues occur."""
+    try:
+        if self.client:
+            await self.client.close()
+    except Exception as e:
+        logger.warning(f"Error closing old client: {e}")
+    if isinstance(self.config, RemoteMilvusVectorIOConfig):
+        logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
+        self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+    else:
+        logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
+        uri = os.path.expanduser(self.config.db_path)
+        self.client = AsyncMilvusClient(uri=uri)
+    # Re-point every cached index at the fresh client
+    for index_wrapper in self.cache.values():
+        if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
+            index_wrapper.index.client = self.client
async def register_vector_db(
self,
@@ -362,7 +411,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
consistency_level = "Strong"
index = VectorDBWithIndex(
vector_db=vector_db,
-index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
+index=MilvusIndex(
+    client=self.client,
+    collection_name=vector_db.identifier,
+    consistency_level=consistency_level,
+    parent_adapter=self,
+),
inference_api=self.inference_api,
)
@@ -381,7 +435,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
index = VectorDBWithIndex(
vector_db=vector_db,
-index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
+index=MilvusIndex(
+    client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self
+),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
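The thread running through this file: every `asyncio.to_thread(self.client.<method>, ...)` hop through the synchronous `MilvusClient` becomes a direct `await` on `AsyncMilvusClient`. A minimal sketch of the two patterns, with an illustrative collection name and query vector (not taken from the diff):

import asyncio
from pymilvus import AsyncMilvusClient, MilvusClient

async def search_old(client: MilvusClient, name: str, vec: list[float]):
    # Old pattern: the sync client blocks, so the call is pushed onto a
    # worker thread to keep the event loop responsive.
    return await asyncio.to_thread(client.search, collection_name=name, data=[vec], limit=5)

async def search_new(client: AsyncMilvusClient, name: str, vec: list[float]):
    # New pattern: the async client exposes coroutines that are awaited
    # directly on the event loop, with no thread hand-off.
    return await client.search(collection_name=name, data=[vec], limit=5)

The `_recreate_client` hook exists because the async client is bound to the event loop it was created on; when a call fails with "attached to a different loop", the adapter rebuilds the client and re-points every cached index at it.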


@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from chromadb import PersistentClient
-from pymilvus import MilvusClient, connections
+from pymilvus import AsyncMilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
@@ -141,7 +141,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
await index.initialize()
index.db_path = db_path
yield index
-index.delete()
+await index.delete()
@pytest.fixture
@@ -178,13 +178,15 @@ def milvus_vec_db_path(tmp_path_factory):
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
-client = MilvusClient(milvus_vec_db_path)
+client = AsyncMilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
+await client.close()
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
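Because `AsyncMilvusClient.close()` is a coroutine, fixture teardown after `yield` must itself be awaited, which is why the fixture above gained `await client.close()`. A minimal sketch of the pattern, assuming an async-capable pytest setup such as pytest-asyncio; the database filename is illustrative:

import pytest
from pymilvus import AsyncMilvusClient

@pytest.fixture
async def async_milvus_client(tmp_path):
    # Milvus Lite accepts a local file path instead of a server URI
    client = AsyncMilvusClient(uri=str(tmp_path / "milvus_lite.db"))
    yield client
    await client.close()  # async teardown runs once the test finishes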


@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@@ -14,7 +14,7 @@ from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
-pymilvus_mock.MilvusClient = MagicMock
+pymilvus_mock.AsyncMilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock
@@ -40,48 +40,55 @@ async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
-client.has_collection.return_value = False  # Initially no collection
-client.create_collection.return_value = None
-client.drop_collection.return_value = None
+client.list_collections = AsyncMock(return_value=[])  # Initially no collections
+client.create_collection = AsyncMock(return_value=None)
+client.drop_collection = AsyncMock(return_value=None)
# Mock insert operation
-client.insert.return_value = {"insert_count": 10}
+client.insert = AsyncMock(return_value={"insert_count": 10})
# Mock search operation - return mock results (data should be dict, not JSON string)
-client.search.return_value = [
+client.search = AsyncMock(
+    return_value=[
        [
            {
                "id": 0,
                "distance": 0.1,
                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
            },
            {
                "id": 1,
                "distance": 0.2,
                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
            },
        ]
    ]
+)
-# Mock query operation for keyword search (data should be dict, not JSON string)
+# Mock async query operation for keyword search (data should be dict, not JSON string)
-client.query.return_value = [
+client.query = AsyncMock(
+    return_value=[
        {
            "chunk_id": "chunk1",
            "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
            "score": 0.9,
        },
        {
            "chunk_id": "chunk2",
            "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
            "score": 0.8,
        },
        {
            "chunk_id": "chunk3",
            "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
            "score": 0.7,
        },
    ]
+)
+client.hybrid_search = AsyncMock(return_value=[])
+client.delete = AsyncMock(return_value=None)
+client.close = AsyncMock(return_value=None)
return client
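Every client method the adapter awaits is stubbed with `AsyncMock` rather than a plain `MagicMock`, because awaiting a `MagicMock` call raises `TypeError`. A small self-contained sketch of the distinction:

import asyncio
from unittest.mock import AsyncMock, MagicMock

async def main():
    client = MagicMock()
    client.list_collections = AsyncMock(return_value=["test_collection"])
    # Calling an AsyncMock returns a coroutine; awaiting it yields return_value.
    assert await client.list_collections() == ["test_collection"]
    # A bare MagicMock call just returns another MagicMock, which is not awaitable.
    assert not asyncio.iscoroutine(client.insert())

asyncio.run(main())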
@@ -96,7 +103,7 @@ async def milvus_index(mock_milvus_client):
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
-mock_milvus_client.has_collection.side_effect = [False, True]
+mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
@@ -113,7 +120,7 @@ async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
-mock_milvus_client.has_collection.return_value = True
+mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
@@ -126,7 +133,7 @@
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
-mock_milvus_client.has_collection.return_value = True
+mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
@@ -139,7 +146,7 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
-mock_milvus_client.has_collection.return_value = True
+mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
@@ -181,7 +188,7 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
-mock_milvus_client.has_collection.return_value = True
+mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.delete()
@@ -192,7 +199,7 @@ async def test_query_hybrid_search_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with RRF reranker."""
-mock_milvus_client.has_collection.return_value = True
+mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
@@ -244,7 +251,7 @@ async def test_query_hybrid_search_weighted(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with weighted reranker."""
-mock_milvus_client.has_collection.return_value = True
+mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
@@ -290,7 +297,7 @@ async def test_query_hybrid_search_default_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
-mock_milvus_client.has_collection.return_value = True
+mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
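Throughout these tests, `has_collection` stubs become `list_collections` stubs because the async adapter now derives existence from list membership rather than a dedicated call. A hedged sketch of the check being exercised (the helper name is illustrative, not from the diff):

from pymilvus import AsyncMilvusClient

async def collection_exists(client: AsyncMilvusClient, name: str) -> bool:
    # Membership in list_collections() stands in for the sync client's
    # has_collection() on the async code path.
    return name in await client.list_collections()

Accordingly, `side_effect = [[], ["test_collection"]]` makes the first existence check report no collections (forcing creation) and the second report the new one, mirroring create-on-first-insert.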


@@ -30,12 +30,12 @@ async def test_initialize_index(vector_index):
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
-vector_index.delete()
-vector_index.initialize()
+await vector_index.delete()
+await vector_index.initialize()
await vector_index.add_chunks(sample_chunks, sample_embeddings)
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content
-vector_index.delete()
+await vector_index.delete()
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
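The final file completes the migration by awaiting the index lifecycle end to end. A minimal sketch (fixture names assumed from the tests above) of the flow a test now exercises:

async def exercise_index(vector_index, sample_chunks, sample_embeddings):
    await vector_index.delete()      # drop any leftover collection
    await vector_index.initialize()  # start from a clean state
    await vector_index.add_chunks(sample_chunks, sample_embeddings)
    resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
    assert resp.chunks[0].content == sample_chunks[0].content
    await vector_index.delete()      # clean up after the assertion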