mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Merge 295d8b99c3
into d266c59c2a
This commit is contained in:
commit
cfe5ac498f
4 changed files with 151 additions and 86 deletions
|
@ -4,12 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
|
||||
from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files.files import Files
|
||||
|
@ -48,12 +47,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
|||
|
||||
class MilvusIndex(EmbeddingIndex):
|
||||
def __init__(
|
||||
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
|
||||
self,
|
||||
client: AsyncMilvusClient,
|
||||
collection_name: str,
|
||||
consistency_level="Strong",
|
||||
kvstore: KVStore | None = None,
|
||||
parent_adapter=None,
|
||||
):
|
||||
self.client = client
|
||||
self.collection_name = sanitize_collection_name(collection_name)
|
||||
self.consistency_level = consistency_level
|
||||
self.kvstore = kvstore
|
||||
self._parent_adapter = parent_adapter
|
||||
|
||||
async def initialize(self):
|
||||
# MilvusIndex does not require explicit initialization
|
||||
|
@ -61,15 +66,39 @@ class MilvusIndex(EmbeddingIndex):
|
|||
pass
|
||||
|
||||
async def delete(self):
|
||||
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
|
||||
try:
|
||||
collections = await self.client.list_collections()
|
||||
if self.collection_name in collections:
|
||||
await self.client.drop_collection(collection_name=self.collection_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
try:
|
||||
collections = await self.client.list_collections()
|
||||
collection_exists = self.collection_name in collections
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
|
||||
# If it's an event loop issue, try to recreate the client
|
||||
if "attached to a different loop" in str(e):
|
||||
logger.warning("Recreating client due to event loop issue")
|
||||
|
||||
if hasattr(self, "_parent_adapter"):
|
||||
await self._parent_adapter._recreate_client()
|
||||
collections = await self.client.list_collections()
|
||||
collection_exists = self.collection_name in collections
|
||||
else:
|
||||
# Assume collection doesn't exist if we can't check
|
||||
collection_exists = False
|
||||
else:
|
||||
# Assume collection doesn't exist if we can't check due to other issues
|
||||
collection_exists = False
|
||||
|
||||
if not collection_exists:
|
||||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
|
@ -123,13 +152,16 @@ class MilvusIndex(EmbeddingIndex):
|
|||
)
|
||||
schema.add_function(bm25_function)
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
self.collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
consistency_level=self.consistency_level,
|
||||
)
|
||||
try:
|
||||
await self.client.create_collection(
|
||||
self.collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
consistency_level=self.consistency_level,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
||||
data = []
|
||||
for chunk, embedding in zip(chunks, embeddings, strict=False):
|
||||
|
@ -143,8 +175,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
}
|
||||
)
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
await self.client.insert(
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
|
@ -153,8 +184,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
raise e
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
search_res = await self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
data=[embedding],
|
||||
anns_field="vector",
|
||||
|
@ -177,8 +207,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
"""
|
||||
try:
|
||||
# Use Milvus's built-in BM25 search
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
search_res = await self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
data=[query_string], # Raw text query
|
||||
anns_field="sparse", # Use sparse field for BM25
|
||||
|
@ -219,8 +248,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
Fallback to simple text search when BM25 search is not available.
|
||||
"""
|
||||
# Simple text search using content field
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.query,
|
||||
search_res = await self.client.query(
|
||||
collection_name=self.collection_name,
|
||||
filter='content like "%{content}%"',
|
||||
filter_params={"content": query_string},
|
||||
|
@ -267,8 +295,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
|
||||
rerank = RRFRanker(impact_factor)
|
||||
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.hybrid_search,
|
||||
search_res = await self.client.hybrid_search(
|
||||
collection_name=self.collection_name,
|
||||
reqs=search_requests,
|
||||
ranker=rerank,
|
||||
|
@ -294,9 +321,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
try:
|
||||
# Use IN clause with square brackets and single quotes for VARCHAR field
|
||||
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
|
||||
await asyncio.to_thread(
|
||||
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
|
||||
)
|
||||
await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
@ -321,6 +346,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = AsyncMilvusClient(uri=uri)
|
||||
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
@ -334,23 +368,38 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
collection_name=vector_db.identifier,
|
||||
consistency_level=self.config.consistency_level,
|
||||
kvstore=self.kvstore,
|
||||
parent_adapter=self,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Load existing OpenAI vector stores into the in-memory cache
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
async def _recreate_client(self) -> None:
|
||||
"""Recreate the AsyncMilvusClient when event loop issues occur"""
|
||||
try:
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing old client: {e}")
|
||||
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
|
||||
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
|
||||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = AsyncMilvusClient(uri=uri)
|
||||
|
||||
for index_wrapper in self.cache.values():
|
||||
if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
|
||||
index_wrapper.index.client = self.client
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
|
@ -362,7 +411,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
consistency_level = "Strong"
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
|
||||
index=MilvusIndex(
|
||||
client=self.client,
|
||||
collection_name=vector_db.identifier,
|
||||
consistency_level=consistency_level,
|
||||
parent_adapter=self,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
|
@ -381,7 +435,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
|
||||
index=MilvusIndex(
|
||||
client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
|
|
|
@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
import numpy as np
|
||||
import pytest
|
||||
from chromadb import PersistentClient
|
||||
from pymilvus import MilvusClient, connections
|
||||
from pymilvus import AsyncMilvusClient, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
|
@ -141,7 +141,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
|
|||
await index.initialize()
|
||||
index.db_path = db_path
|
||||
yield index
|
||||
index.delete()
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -178,13 +178,15 @@ def milvus_vec_db_path(tmp_path_factory):
|
|||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
|
||||
client = MilvusClient(milvus_vec_db_path)
|
||||
client = AsyncMilvusClient(milvus_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
|
||||
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||
index.db_path = milvus_vec_db_path
|
||||
yield index
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -14,7 +14,7 @@ from llama_stack.apis.vector_io import QueryChunksResponse
|
|||
# Mock the entire pymilvus module
|
||||
pymilvus_mock = MagicMock()
|
||||
pymilvus_mock.DataType = MagicMock()
|
||||
pymilvus_mock.MilvusClient = MagicMock
|
||||
pymilvus_mock.AsyncMilvusClient = MagicMock
|
||||
pymilvus_mock.RRFRanker = MagicMock
|
||||
pymilvus_mock.WeightedRanker = MagicMock
|
||||
pymilvus_mock.AnnSearchRequest = MagicMock
|
||||
|
@ -40,48 +40,55 @@ async def mock_milvus_client() -> MagicMock:
|
|||
"""Create a mock Milvus client with common method behaviors."""
|
||||
client = MagicMock()
|
||||
|
||||
# Mock collection operations
|
||||
client.has_collection.return_value = False # Initially no collection
|
||||
client.create_collection.return_value = None
|
||||
client.drop_collection.return_value = None
|
||||
client.list_collections = AsyncMock(return_value=[]) # Initially no collections
|
||||
client.create_collection = AsyncMock(return_value=None)
|
||||
client.drop_collection = AsyncMock(return_value=None)
|
||||
|
||||
# Mock insert operation
|
||||
client.insert.return_value = {"insert_count": 10}
|
||||
client.insert = AsyncMock(return_value={"insert_count": 10})
|
||||
|
||||
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||
client.search.return_value = [
|
||||
[
|
||||
client.search = AsyncMock(
|
||||
return_value=[
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
# Mock async query operation for keyword search (data should be dict, not JSON string)
|
||||
client.query = AsyncMock(
|
||||
return_value=[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||
"score": 0.8,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk3",
|
||||
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||
"score": 0.7,
|
||||
},
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
# Mock query operation for keyword search (data should be dict, not JSON string)
|
||||
client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||
"score": 0.8,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk3",
|
||||
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||
"score": 0.7,
|
||||
},
|
||||
]
|
||||
client.hybrid_search = AsyncMock(return_value=[])
|
||||
|
||||
client.delete = AsyncMock(return_value=None)
|
||||
|
||||
client.close = AsyncMock(return_value=None)
|
||||
|
||||
return client
|
||||
|
||||
|
@ -96,7 +103,7 @@ async def milvus_index(mock_milvus_client):
|
|||
|
||||
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
# Setup: collection doesn't exist initially, then exists after creation
|
||||
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||
mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]]
|
||||
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
|
@ -113,7 +120,7 @@ async def test_query_chunks_vector(
|
|||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
# Setup: Add chunks first
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
mock_milvus_client.list_collections.return_value = ["test_collection"]
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test vector search
|
||||
|
@ -126,7 +133,7 @@ async def test_query_chunks_vector(
|
|||
|
||||
|
||||
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
mock_milvus_client.list_collections.return_value = ["test_collection"]
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test keyword search
|
||||
|
@ -139,7 +146,7 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
|
|||
|
||||
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
mock_milvus_client.list_collections.return_value = ["test_collection"]
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Force BM25 search to fail
|
||||
|
@ -181,7 +188,7 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
|
|||
|
||||
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||
# Test collection deletion
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
mock_milvus_client.list_collections.return_value = ["test_collection"]
|
||||
|
||||
await milvus_index.delete()
|
||||
|
||||
|
@ -192,7 +199,7 @@ async def test_query_hybrid_search_rrf(
|
|||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with RRF reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
mock_milvus_client.list_collections.return_value = ["test_collection"]
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
|
@ -244,7 +251,7 @@ async def test_query_hybrid_search_weighted(
|
|||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with weighted reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
mock_milvus_client.list_collections.return_value = ["test_collection"]
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
|
@ -290,7 +297,7 @@ async def test_query_hybrid_search_default_rrf(
|
|||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
mock_milvus_client.list_collections.return_value = ["test_collection"]
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
|
|
|
@ -30,12 +30,12 @@ async def test_initialize_index(vector_index):
|
|||
|
||||
|
||||
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
|
||||
vector_index.delete()
|
||||
vector_index.initialize()
|
||||
await vector_index.delete()
|
||||
await vector_index.initialize()
|
||||
await vector_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
|
||||
assert resp.chunks[0].content == sample_chunks[0].content
|
||||
vector_index.delete()
|
||||
await vector_index.delete()
|
||||
|
||||
|
||||
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue