This commit is contained in:
Mustafa Elbehery 2025-10-03 14:11:23 +02:00 committed by GitHub
commit cfe5ac498f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 151 additions and 86 deletions

View file

@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import os import os
from typing import Any from typing import Any
from numpy.typing import NDArray from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files from llama_stack.apis.files.files import Files
@ -48,12 +47,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class MilvusIndex(EmbeddingIndex): class MilvusIndex(EmbeddingIndex):
def __init__( def __init__(
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None self,
client: AsyncMilvusClient,
collection_name: str,
consistency_level="Strong",
kvstore: KVStore | None = None,
parent_adapter=None,
): ):
self.client = client self.client = client
self.collection_name = sanitize_collection_name(collection_name) self.collection_name = sanitize_collection_name(collection_name)
self.consistency_level = consistency_level self.consistency_level = consistency_level
self.kvstore = kvstore self.kvstore = kvstore
self._parent_adapter = parent_adapter
async def initialize(self): async def initialize(self):
# MilvusIndex does not require explicit initialization # MilvusIndex does not require explicit initialization
@ -61,15 +66,39 @@ class MilvusIndex(EmbeddingIndex):
pass pass
async def delete(self): async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name): try:
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) collections = await self.client.list_collections()
if self.collection_name in collections:
await self.client.drop_collection(collection_name=self.collection_name)
except Exception as e:
logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), ( assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
) )
if not await asyncio.to_thread(self.client.has_collection, self.collection_name): try:
collections = await self.client.list_collections()
collection_exists = self.collection_name in collections
except Exception as e:
logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
# If it's an event loop issue, try to recreate the client
if "attached to a different loop" in str(e):
logger.warning("Recreating client due to event loop issue")
if hasattr(self, "_parent_adapter"):
await self._parent_adapter._recreate_client()
collections = await self.client.list_collections()
collection_exists = self.collection_name in collections
else:
# Assume collection doesn't exist if we can't check
collection_exists = False
else:
# Assume collection doesn't exist if we can't check due to other issues
collection_exists = False
if not collection_exists:
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search # Create schema for vector search
schema = self.client.create_schema() schema = self.client.create_schema()
@ -123,13 +152,16 @@ class MilvusIndex(EmbeddingIndex):
) )
schema.add_function(bm25_function) schema.add_function(bm25_function)
await asyncio.to_thread( try:
self.client.create_collection, await self.client.create_collection(
self.collection_name, self.collection_name,
schema=schema, schema=schema,
index_params=index_params, index_params=index_params,
consistency_level=self.consistency_level, consistency_level=self.consistency_level,
) )
except Exception as e:
logger.error(f"Failed to create collection {self.collection_name}: {e}")
raise e
data = [] data = []
for chunk, embedding in zip(chunks, embeddings, strict=False): for chunk, embedding in zip(chunks, embeddings, strict=False):
@ -143,8 +175,7 @@ class MilvusIndex(EmbeddingIndex):
} }
) )
try: try:
await asyncio.to_thread( await self.client.insert(
self.client.insert,
self.collection_name, self.collection_name,
data=data, data=data,
) )
@ -153,8 +184,7 @@ class MilvusIndex(EmbeddingIndex):
raise e raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread( search_res = await self.client.search(
self.client.search,
collection_name=self.collection_name, collection_name=self.collection_name,
data=[embedding], data=[embedding],
anns_field="vector", anns_field="vector",
@ -177,8 +207,7 @@ class MilvusIndex(EmbeddingIndex):
""" """
try: try:
# Use Milvus's built-in BM25 search # Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread( search_res = await self.client.search(
self.client.search,
collection_name=self.collection_name, collection_name=self.collection_name,
data=[query_string], # Raw text query data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25 anns_field="sparse", # Use sparse field for BM25
@ -219,8 +248,7 @@ class MilvusIndex(EmbeddingIndex):
Fallback to simple text search when BM25 search is not available. Fallback to simple text search when BM25 search is not available.
""" """
# Simple text search using content field # Simple text search using content field
search_res = await asyncio.to_thread( search_res = await self.client.query(
self.client.query,
collection_name=self.collection_name, collection_name=self.collection_name,
filter='content like "%{content}%"', filter='content like "%{content}%"',
filter_params={"content": query_string}, filter_params={"content": query_string},
@ -267,8 +295,7 @@ class MilvusIndex(EmbeddingIndex):
impact_factor = (reranker_params or {}).get("impact_factor", 60.0) impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor) rerank = RRFRanker(impact_factor)
search_res = await asyncio.to_thread( search_res = await self.client.hybrid_search(
self.client.hybrid_search,
collection_name=self.collection_name, collection_name=self.collection_name,
reqs=search_requests, reqs=search_requests,
ranker=rerank, ranker=rerank,
@ -294,9 +321,7 @@ class MilvusIndex(EmbeddingIndex):
try: try:
# Use IN clause with square brackets and single quotes for VARCHAR field # Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids) chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread( await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
except Exception as e: except Exception as e:
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise raise
@ -321,6 +346,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def initialize(self) -> None: async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore) self.kvstore = await kvstore_impl(self.config.kvstore)
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = AsyncMilvusClient(uri=uri)
start_key = VECTOR_DBS_PREFIX start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff" end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
@ -334,23 +368,38 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
collection_name=vector_db.identifier, collection_name=vector_db.identifier,
consistency_level=self.config.consistency_level, consistency_level=self.config.consistency_level,
kvstore=self.kvstore, kvstore=self.kvstore,
parent_adapter=self,
), ),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db.identifier] = index self.cache[vector_db.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
# Load existing OpenAI vector stores into the in-memory cache # Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores() await self.initialize_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client.close() if self.client:
await self.client.close()
async def _recreate_client(self) -> None:
"""Recreate the AsyncMilvusClient when event loop issues occur"""
try:
if self.client:
await self.client.close()
except Exception as e:
logger.warning(f"Error closing old client: {e}")
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = AsyncMilvusClient(uri=uri)
for index_wrapper in self.cache.values():
if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
index_wrapper.index.client = self.client
async def register_vector_db( async def register_vector_db(
self, self,
@ -362,7 +411,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
consistency_level = "Strong" consistency_level = "Strong"
index = VectorDBWithIndex( index = VectorDBWithIndex(
vector_db=vector_db, vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level), index=MilvusIndex(
client=self.client,
collection_name=vector_db.identifier,
consistency_level=consistency_level,
parent_adapter=self,
),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
@ -381,7 +435,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
index = VectorDBWithIndex( index = VectorDBWithIndex(
vector_db=vector_db, vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore), index=MilvusIndex(
client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self
),
inference_api=self.inference_api, inference_api=self.inference_api,
) )
self.cache[vector_db_id] = index self.cache[vector_db_id] = index

View file

@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
from chromadb import PersistentClient from chromadb import PersistentClient
from pymilvus import MilvusClient, connections from pymilvus import AsyncMilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
@ -141,7 +141,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
await index.initialize() await index.initialize()
index.db_path = db_path index.db_path = db_path
yield index yield index
index.delete() await index.delete()
@pytest.fixture @pytest.fixture
@ -178,13 +178,15 @@ def milvus_vec_db_path(tmp_path_factory):
@pytest.fixture @pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension): async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path) client = AsyncMilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path) connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong") index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path index.db_path = milvus_vec_db_path
yield index yield index
await client.close()
@pytest.fixture @pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
@ -14,7 +14,7 @@ from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module # Mock the entire pymilvus module
pymilvus_mock = MagicMock() pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock() pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock pymilvus_mock.AsyncMilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock pymilvus_mock.AnnSearchRequest = MagicMock
@ -40,48 +40,55 @@ async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors.""" """Create a mock Milvus client with common method behaviors."""
client = MagicMock() client = MagicMock()
# Mock collection operations client.list_collections = AsyncMock(return_value=[]) # Initially no collections
client.has_collection.return_value = False # Initially no collection client.create_collection = AsyncMock(return_value=None)
client.create_collection.return_value = None client.drop_collection = AsyncMock(return_value=None)
client.drop_collection.return_value = None
# Mock insert operation client.insert = AsyncMock(return_value={"insert_count": 10})
client.insert.return_value = {"insert_count": 10}
# Mock search operation - return mock results (data should be dict, not JSON string) client.search = AsyncMock(
client.search.return_value = [ return_value=[
[ [
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
)
# Mock async query operation for keyword search (data should be dict, not JSON string)
client.query = AsyncMock(
return_value=[
{ {
"id": 0, "chunk_id": "chunk1",
"distance": 0.1, "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, "score": 0.9,
}, },
{ {
"id": 1, "chunk_id": "chunk2",
"distance": 0.2, "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, "score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
}, },
] ]
] )
# Mock query operation for keyword search (data should be dict, not JSON string) client.hybrid_search = AsyncMock(return_value=[])
client.query.return_value = [
{ client.delete = AsyncMock(return_value=None)
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, client.close = AsyncMock(return_value=None)
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
return client return client
@ -96,7 +103,7 @@ async def milvus_index(mock_milvus_client):
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation # Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True] mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]]
await milvus_index.add_chunks(sample_chunks, sample_embeddings) await milvus_index.add_chunks(sample_chunks, sample_embeddings)
@ -113,7 +120,7 @@ async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
): ):
# Setup: Add chunks first # Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings) await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search # Test vector search
@ -126,7 +133,7 @@ async def test_query_chunks_vector(
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings) await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search # Test keyword search
@ -139,7 +146,7 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search.""" """Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings) await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail # Force BM25 search to fail
@ -181,7 +188,7 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
async def test_delete_collection(milvus_index, mock_milvus_client): async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion # Test collection deletion
mock_milvus_client.has_collection.return_value = True mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.delete() await milvus_index.delete()
@ -192,7 +199,7 @@ async def test_query_hybrid_search_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
): ):
"""Test hybrid search with RRF reranker.""" """Test hybrid search with RRF reranker."""
mock_milvus_client.has_collection.return_value = True mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings) await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results # Mock hybrid search results
@ -244,7 +251,7 @@ async def test_query_hybrid_search_weighted(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
): ):
"""Test hybrid search with weighted reranker.""" """Test hybrid search with weighted reranker."""
mock_milvus_client.has_collection.return_value = True mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings) await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results # Mock hybrid search results
@ -290,7 +297,7 @@ async def test_query_hybrid_search_default_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
): ):
"""Test hybrid search with default RRF reranker (no reranker_type specified).""" """Test hybrid search with default RRF reranker (no reranker_type specified)."""
mock_milvus_client.has_collection.return_value = True mock_milvus_client.list_collections.return_value = ["test_collection"]
await milvus_index.add_chunks(sample_chunks, sample_embeddings) await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results # Mock hybrid search results

View file

@ -30,12 +30,12 @@ async def test_initialize_index(vector_index):
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete() await vector_index.delete()
vector_index.initialize() await vector_index.initialize()
await vector_index.add_chunks(sample_chunks, sample_embeddings) await vector_index.add_chunks(sample_chunks, sample_embeddings)
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content assert resp.chunks[0].content == sample_chunks[0].content
vector_index.delete() await vector_index.delete()
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension): async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):