Merge a6eed99790 into ee107aadd6
This commit is contained in: commit 154b7f568f
7 changed files with 547 additions and 40 deletions
pyproject.toml
@@ -125,6 +125,8 @@ unit = [
     "together",
     "coverage",
     "moto[s3]>=5.1.10",
+    "pymilvus>=2.6.1",
+    "milvus-lite>=2.5.0",
 ]
 # These are the core dependencies required for running integration tests. They are shared across all
 # providers. If a provider requires additional dependencies, please add them to your environment
llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -4,12 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import asyncio
 import os
 from typing import Any

 from numpy.typing import NDArray
-from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
+from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker

 from llama_stack.core.storage.kvstore import kvstore_impl
 from llama_stack.log import get_logger
@@ -49,12 +48,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten

 class MilvusIndex(EmbeddingIndex):
     def __init__(
-        self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
+        self,
+        client: AsyncMilvusClient,
+        collection_name: str,
+        consistency_level="Strong",
+        kvstore: KVStore | None = None,
+        parent_adapter=None,
     ):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name)
         self.consistency_level = consistency_level
         self.kvstore = kvstore
+        self._parent_adapter = parent_adapter

     async def initialize(self):
         # MilvusIndex does not require explicit initialization
@@ -62,15 +67,39 @@ class MilvusIndex(EmbeddingIndex):
         pass

     async def delete(self):
-        if await asyncio.to_thread(self.client.has_collection, self.collection_name):
-            await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
+        try:
+            collections = await self.client.list_collections()
+            if self.collection_name in collections:
+                await self.client.drop_collection(collection_name=self.collection_name)
+        except Exception as e:
+            logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")

     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )

-        if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+        try:
+            collections = await self.client.list_collections()
+            collection_exists = self.collection_name in collections
+        except Exception as e:
+            logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
+            # If it's an event loop issue, try to recreate the client
+            if "attached to a different loop" in str(e):
+                logger.warning("Recreating client due to event loop issue")
+
+                if hasattr(self, "_parent_adapter"):
+                    await self._parent_adapter._recreate_client()
+                    collections = await self.client.list_collections()
+                    collection_exists = self.collection_name in collections
+                else:
+                    # Assume collection doesn't exist if we can't check
+                    collection_exists = False
+            else:
+                # Assume collection doesn't exist if we can't check due to other issues
+                collection_exists = False
+
+        if not collection_exists:
             logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
             # Create schema for vector search
             schema = self.client.create_schema()
@@ -101,13 +130,16 @@ class MilvusIndex(EmbeddingIndex):
             )
             schema.add_function(bm25_function)

-            await asyncio.to_thread(
-                self.client.create_collection,
+            try:
+                await self.client.create_collection(
                     self.collection_name,
                     schema=schema,
                     index_params=index_params,
                     consistency_level=self.consistency_level,
                 )
+            except Exception as e:
+                logger.error(f"Failed to create collection {self.collection_name}: {e}")
+                raise e

         data = []
         for chunk, embedding in zip(chunks, embeddings, strict=False):
@@ -121,14 +153,16 @@
                 }
             )
         try:
-            await asyncio.to_thread(self.client.insert, self.collection_name, data=data)
+            await self.client.insert(
+                self.collection_name,
+                data=data,
+            )
         except Exception as e:
             logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
             raise e

     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        search_res = await asyncio.to_thread(
-            self.client.search,
+        search_res = await self.client.search(
             collection_name=self.collection_name,
             data=[embedding],
             anns_field="vector",
@@ -146,8 +180,7 @@ class MilvusIndex(EmbeddingIndex):
         """
         try:
             # Use Milvus's built-in BM25 search
-            search_res = await asyncio.to_thread(
-                self.client.search,
+            search_res = await self.client.search(
                 collection_name=self.collection_name,
                 data=[query_string],  # Raw text query
                 anns_field="sparse",  # Use sparse field for BM25
@@ -183,8 +216,7 @@ class MilvusIndex(EmbeddingIndex):
         Fallback to simple text search when BM25 search is not available.
         """
         # Simple text search using content field
-        search_res = await asyncio.to_thread(
-            self.client.query,
+        search_res = await self.client.query(
             collection_name=self.collection_name,
             filter='content like "%{content}%"',
             filter_params={"content": query_string},
@@ -231,8 +263,7 @@ class MilvusIndex(EmbeddingIndex):
             impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
             rerank = RRFRanker(impact_factor)

-        search_res = await asyncio.to_thread(
-            self.client.hybrid_search,
+        search_res = await self.client.hybrid_search(
             collection_name=self.collection_name,
             reqs=search_requests,
             ranker=rerank,
@@ -258,9 +289,7 @@ class MilvusIndex(EmbeddingIndex):
         try:
             # Use IN clause with square brackets and single quotes for VARCHAR field
             chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
-            await asyncio.to_thread(
-                self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
-            )
+            await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
         except Exception as e:
             logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
             raise
@@ -283,6 +312,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.persistence)
+
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Connecting to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
         start_key = VECTOR_DBS_PREFIX
         end_key = f"{VECTOR_DBS_PREFIX}\xff"
         stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
@@ -296,26 +334,41 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
                     collection_name=vector_store.identifier,
                     consistency_level=self.config.consistency_level,
                     kvstore=self.kvstore,
+                    parent_adapter=self,
                 ),
                 inference_api=self.inference_api,
             )
             self.cache[vector_store.identifier] = index
-        if isinstance(self.config, RemoteMilvusVectorIOConfig):
-            logger.info(f"Connecting to Milvus server at {self.config.uri}")
-            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
-        else:
-            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
-            uri = os.path.expanduser(self.config.db_path)
-            self.client = MilvusClient(uri=uri)

         # Load existing OpenAI vector stores into the in-memory cache
         await self.initialize_openai_vector_stores()

     async def shutdown(self) -> None:
-        self.client.close()
+        if self.client:
+            await self.client.close()
         # Clean up mixin resources (file batch tasks)
         await super().shutdown()

+    async def _recreate_client(self) -> None:
+        """Recreate the AsyncMilvusClient when event loop issues occur"""
+        try:
+            if self.client:
+                await self.client.close()
+        except Exception as e:
+            logger.warning(f"Error closing old client: {e}")
+
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
+        for index_wrapper in self.cache.values():
+            if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
+                index_wrapper.index.client = self.client
+
     async def register_vector_store(self, vector_store: VectorStore) -> None:
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
             consistency_level = self.config.consistency_level
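
The string check in add_chunks ("attached to a different loop") targets a standard asyncio failure that _recreate_client is meant to recover from: an awaitable created on one event loop being awaited from another. A small, self-contained illustration follows (not part of the PR; the exact message text varies by Python version and object type):

import asyncio

loop_a = asyncio.new_event_loop()
pending = loop_a.create_future()  # future bound to loop_a, never resolved


async def consume():
    # Awaiting a future that belongs to another loop fails fast instead of hanging.
    return await pending


try:
    asyncio.run(consume())  # asyncio.run() starts a *different* event loop
except RuntimeError as e:
    print(e)  # e.g. "Task ... got Future ... attached to a different loop"
finally:
    loop_a.close()
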
@@ -323,7 +376,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             consistency_level = "Strong"
         index = VectorStoreWithIndex(
             vector_store=vector_store,
-            index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
+            index=MilvusIndex(
+                client=self.client,
+                collection_name=vector_store.identifier,
+                consistency_level=consistency_level,
+                kvstore=self.kvstore,
+                parent_adapter=self,
+            ),
             inference_api=self.inference_api,
         )

@@ -345,7 +404,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc
             vector_store = VectorStore.model_validate_json(vector_store_data)
             index = VectorStoreWithIndex(
                 vector_store=vector_store,
-                index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
+                index=MilvusIndex(
+                    client=self.client,
+                    collection_name=vector_store.identifier,
+                    kvstore=self.kvstore,
+                    parent_adapter=self,
+                ),
                 inference_api=self.inference_api,
             )
             self.cache[vector_store_id] = index
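
For orientation, here is a minimal sketch (not part of the diff) of the pattern this change applies throughout milvus.py: blocking MilvusClient calls wrapped in asyncio.to_thread are replaced by directly awaiting AsyncMilvusClient coroutines. It assumes pymilvus is installed and that a Milvus Lite database file can be created at the illustrative path ./milvus_demo.db.

import asyncio

from pymilvus import AsyncMilvusClient, MilvusClient


async def list_collections_old(uri: str) -> list[str]:
    # Before: the synchronous client blocks, so calls are offloaded to a worker thread.
    client = MilvusClient(uri=uri)
    try:
        return await asyncio.to_thread(client.list_collections)
    finally:
        client.close()


async def list_collections_new(uri: str) -> list[str]:
    # After: the async client is awaited directly; no thread hop is needed.
    client = AsyncMilvusClient(uri=uri)
    try:
        return await client.list_collections()
    finally:
        await client.close()


if __name__ == "__main__":
    print(asyncio.run(list_collections_new("./milvus_demo.db")))
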
tests/conftest.py (new file, 8 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# This file intentionally left empty - pytest will auto-discover conftest.py files
# in subdirectories and load their fixtures automatically.

tests/unit/providers/vector_io/conftest.py
@@ -9,19 +9,23 @@ from unittest.mock import AsyncMock, MagicMock, patch

 import numpy as np
 import pytest
+from pymilvus import AsyncMilvusClient, connections

 from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
 from llama_stack.core.storage.kvstore import register_kvstore_backends
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
+from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
+from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
 from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
 from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore

 EMBEDDING_DIMENSION = 768
 COLLECTION_PREFIX = "test_collection"
+MILVUS_ALIAS = "test_milvus"


 @pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
@@ -140,7 +144,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
     await index.initialize()
     index.db_path = db_path
     yield index
-    index.delete()
+    await index.delete()


 @pytest.fixture
@@ -169,6 +173,48 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
     await adapter.shutdown()


+@pytest.fixture(scope="session")
+def milvus_vec_db_path(tmp_path_factory):
+    db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
+    return db_path
+
+
+@pytest.fixture
+async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
+    client = AsyncMilvusClient(milvus_vec_db_path)
+    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
+    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
+    index = MilvusIndex(client, name, consistency_level="Strong")
+    index.db_path = milvus_vec_db_path
+    yield index
+
+    await client.close()
+
+
+@pytest.fixture
+async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
+    config = MilvusVectorIOConfig(
+        db_path=milvus_vec_db_path,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = MilvusVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    await adapter.initialize()
+    await adapter.register_vector_store(
+        VectorStore(
+            identifier=adapter.metadata_collection_name,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=128,
+        )
+    )
+    yield adapter
+    await adapter.shutdown()
+
+
 @pytest.fixture
 def faiss_vec_db_path(tmp_path_factory):
     db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
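
As a usage sketch (illustrative only, not part of the diff), a unit test could consume the new milvus_vec_index fixture the same way the existing sqlite/faiss tests do; the test name below is hypothetical and it assumes the sample_chunks and sample_embeddings fixtures already provided by the shared conftest:

# Illustrative only: exercises the AsyncMilvusClient-backed MilvusIndex fixture above.
async def test_milvus_vec_index_roundtrip(milvus_vec_index, sample_chunks, sample_embeddings):
    await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
    resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
    assert resp.chunks[0].content == sample_chunks[0].content
    await milvus_vec_index.delete()
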
tests/unit/providers/vector_io/remote/test_milvus.py (new file, 383 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest

from llama_stack.apis.vector_io import QueryChunksResponse

# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.AsyncMilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock

# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex

# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto

MILVUS_PROVIDER = "milvus"

@pytest.fixture
async def mock_milvus_client() -> MagicMock:
    """Create a mock Milvus client with common method behaviors."""
    client = MagicMock()

    client.list_collections = AsyncMock(return_value=[])  # Initially no collections
    client.create_collection = AsyncMock(return_value=None)
    client.drop_collection = AsyncMock(return_value=None)

    client.insert = AsyncMock(return_value={"insert_count": 10})

    client.search = AsyncMock(
        return_value=[
            [
                {
                    "id": 0,
                    "distance": 0.1,
                    "entity": {
                        "chunk_content": {
                            "chunk_id": "chunk1",
                            "content": "mock chunk 1",
                            "metadata": {"document_id": "doc1"},
                        }
                    },
                },
                {
                    "id": 1,
                    "distance": 0.2,
                    "entity": {
                        "chunk_content": {
                            "chunk_id": "chunk2",
                            "content": "mock chunk 2",
                            "metadata": {"document_id": "doc2"},
                        }
                    },
                },
            ]
        ]
    )

    # Mock async query operation for keyword search (data should be dict, not JSON string)
    client.query = AsyncMock(
        return_value=[
            {
                "chunk_id": "chunk1",
                "chunk_content": {"chunk_id": "chunk1", "content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
                "score": 0.9,
            },
            {
                "chunk_id": "chunk2",
                "chunk_content": {"chunk_id": "chunk2", "content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
                "score": 0.8,
            },
            {
                "chunk_id": "chunk3",
                "chunk_content": {"chunk_id": "chunk3", "content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
                "score": 0.7,
            },
        ]
    )

    client.hybrid_search = AsyncMock(return_value=[])

    client.delete = AsyncMock(return_value=None)

    client.close = AsyncMock(return_value=None)

    return client

@pytest.fixture
async def milvus_index(mock_milvus_client):
    """Create a MilvusIndex with mocked client."""
    index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
    yield index
    # No real cleanup needed since we're using mocks


async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    # Setup: collection doesn't exist initially, then exists after creation
    mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]]

    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Verify collection was created and data was inserted
    mock_milvus_client.create_collection.assert_called_once()
    mock_milvus_client.insert.assert_called_once()

    # Verify the insert call had the right number of chunks
    insert_call = mock_milvus_client.insert.call_args
    assert len(insert_call[1]["data"]) == len(sample_chunks)


async def test_query_chunks_vector(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    # Setup: Add chunks first
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Test vector search
    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
    response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2
    mock_milvus_client.search.assert_called_once()


async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Test keyword search
    query_string = "Sentence 5"
    response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2


async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """Test that when BM25 search fails, the system falls back to simple text search."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Force BM25 search to fail
    mock_milvus_client.search.side_effect = Exception("BM25 search not available")

    # Mock simple text search results
    mock_milvus_client.query.return_value = [
        {
            "chunk_id": "chunk1",
            "chunk_content": {
                "chunk_id": "chunk1",
                "content": "Python programming language",
                "metadata": {"document_id": "doc1"},
            },
        },
        {
            "chunk_id": "chunk2",
            "chunk_content": {
                "chunk_id": "chunk2",
                "content": "Machine learning algorithms",
                "metadata": {"document_id": "doc2"},
            },
        },
    ]

    # Test keyword search that should fall back to simple text search
    query_string = "Python"
    response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)

    # Verify response structure
    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) > 0, "Fallback search should return results"

    # Verify that simple text search was used (query method called instead of search)
    mock_milvus_client.query.assert_called_once()
    mock_milvus_client.search.assert_called_once()  # Called once but failed

    # Verify the query uses parameterized filter with filter_params
    query_call_args = mock_milvus_client.query.call_args
    assert "filter" in query_call_args[1], "Query should include filter for text search"
    assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
    assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"

    # Verify all returned chunks have score 1.0 (simple binary scoring)
    assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"

async def test_delete_collection(milvus_index, mock_milvus_client):
    # Test collection deletion
    mock_milvus_client.list_collections.return_value = ["test_collection"]

    await milvus_index.delete()

    mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)


async def test_query_hybrid_search_rrf(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Test hybrid search with RRF reranker."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Mock hybrid search results
    mock_milvus_client.hybrid_search.return_value = [
        [
            {
                "id": 0,
                "distance": 0.1,
                "entity": {
                    "chunk_content": {
                        "chunk_id": "chunk1",
                        "content": "mock chunk 1",
                        "metadata": {"document_id": "doc1"},
                    }
                },
            },
            {
                "id": 1,
                "distance": 0.2,
                "entity": {
                    "chunk_content": {
                        "chunk_id": "chunk2",
                        "content": "mock chunk 2",
                        "metadata": {"document_id": "doc2"},
                    }
                },
            },
        ]
    ]

    # Test hybrid search with RRF reranker
    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
    query_string = "test query"
    response = await milvus_index.query_hybrid(
        embedding=query_embedding,
        query_string=query_string,
        k=2,
        score_threshold=0.0,
        reranker_type="rrf",
        reranker_params={"impact_factor": 60.0},
    )

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2
    assert len(response.scores) == 2

    # Verify hybrid search was called with correct parameters
    mock_milvus_client.hybrid_search.assert_called_once()
    call_args = mock_milvus_client.hybrid_search.call_args

    # Check that the request contains both vector and BM25 search requests
    reqs = call_args[1]["reqs"]
    assert len(reqs) == 2
    assert reqs[0].anns_field == "vector"
    assert reqs[1].anns_field == "sparse"
    ranker = call_args[1]["ranker"]
    assert ranker is not None

async def test_query_hybrid_search_weighted(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Test hybrid search with weighted reranker."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Mock hybrid search results
    mock_milvus_client.hybrid_search.return_value = [
        [
            {
                "id": 0,
                "distance": 0.1,
                "entity": {
                    "chunk_content": {
                        "chunk_id": "chunk1",
                        "content": "mock chunk 1",
                        "metadata": {"document_id": "doc1"},
                    }
                },
            },
            {
                "id": 1,
                "distance": 0.2,
                "entity": {
                    "chunk_content": {
                        "chunk_id": "chunk2",
                        "content": "mock chunk 2",
                        "metadata": {"document_id": "doc2"},
                    }
                },
            },
        ]
    ]

    # Test hybrid search with weighted reranker
    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
    query_string = "test query"
    response = await milvus_index.query_hybrid(
        embedding=query_embedding,
        query_string=query_string,
        k=2,
        score_threshold=0.0,
        reranker_type="weighted",
        reranker_params={"alpha": 0.7},
    )

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2
    assert len(response.scores) == 2

    # Verify hybrid search was called with correct parameters
    mock_milvus_client.hybrid_search.assert_called_once()
    call_args = mock_milvus_client.hybrid_search.call_args
    ranker = call_args[1]["ranker"]
    assert ranker is not None


async def test_query_hybrid_search_default_rrf(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Test hybrid search with default RRF reranker (no reranker_type specified)."""
    mock_milvus_client.list_collections.return_value = ["test_collection"]
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Mock hybrid search results
    mock_milvus_client.hybrid_search.return_value = [
        [
            {
                "id": 0,
                "distance": 0.1,
                "entity": {
                    "chunk_content": {
                        "chunk_id": "chunk1",
                        "content": "mock chunk 1",
                        "metadata": {"document_id": "doc1"},
                    }
                },
            },
        ]
    ]

    # Test hybrid search with default reranker (should be RRF)
    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
    query_string = "test query"
    response = await milvus_index.query_hybrid(
        embedding=query_embedding,
        query_string=query_string,
        k=1,
        score_threshold=0.0,
        reranker_type="unknown_type",  # Should default to RRF
        reranker_params=None,  # Should use default impact_factor
    )

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 1

    # Verify hybrid search was called with RRF reranker
    mock_milvus_client.hybrid_search.assert_called_once()
    call_args = mock_milvus_client.hybrid_search.call_args
    ranker = call_args[1]["ranker"]
    assert ranker is not None

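One detail worth noting in these tests is the use of AsyncMock for every client method: calling an AsyncMock returns an awaitable, so the adapter code can await the mocked client exactly as it awaits a real AsyncMilvusClient. A minimal, self-contained illustration (not from the PR):

import asyncio
from unittest.mock import AsyncMock, MagicMock

mock_client = MagicMock()
mock_client.list_collections = AsyncMock(return_value=["test_collection"])


async def check() -> None:
    # The call site looks identical to awaiting the real async client.
    names = await mock_client.list_collections()
    assert names == ["test_collection"]
    mock_client.list_collections.assert_awaited_once()


asyncio.run(check())
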
@@ -48,12 +48,12 @@ async def test_initialize_index(vector_index):


 async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
-    vector_index.delete()
-    vector_index.initialize()
+    await vector_index.delete()
+    await vector_index.initialize()
     await vector_index.add_chunks(sample_chunks, sample_embeddings)
     resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
     assert resp.chunks[0].content == sample_chunks[0].content
-    vector_index.delete()
+    await vector_index.delete()


 async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
uv.lock (generated, 6 lines changed)
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.12"
 resolution-markers = [
     "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
@@ -2133,9 +2133,11 @@ unit = [
     { name = "faiss-cpu" },
     { name = "litellm" },
     { name = "mcp" },
+    { name = "milvus-lite" },
     { name = "moto", extra = ["s3"] },
     { name = "ollama" },
     { name = "psycopg2-binary" },
+    { name = "pymilvus" },
     { name = "pypdf" },
     { name = "sqlalchemy", extra = ["asyncio"] },
     { name = "sqlite-vec" },
@@ -2277,9 +2279,11 @@ unit = [
     { name = "faiss-cpu" },
     { name = "litellm" },
     { name = "mcp" },
+    { name = "milvus-lite", specifier = ">=2.5.0" },
     { name = "moto", extras = ["s3"], specifier = ">=5.1.10" },
     { name = "ollama" },
     { name = "psycopg2-binary", specifier = ">=2.9.0" },
+    { name = "pymilvus", specifier = ">=2.6.1" },
     { name = "pypdf", specifier = ">=6.1.3" },
     { name = "sqlalchemy" },
     { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },