This commit is contained in:
Mustafa Elbehery 2025-10-03 14:11:23 +02:00 committed by GitHub
commit cfe5ac498f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 151 additions and 86 deletions

View file

@@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from typing import Any
from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
@@ -48,12 +47,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class MilvusIndex(EmbeddingIndex):
def __init__(
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
self,
client: AsyncMilvusClient,
collection_name: str,
consistency_level="Strong",
kvstore: KVStore | None = None,
parent_adapter=None,
):
self.client = client
self.collection_name = sanitize_collection_name(collection_name)
self.consistency_level = consistency_level
self.kvstore = kvstore
self._parent_adapter = parent_adapter
async def initialize(self):
# MilvusIndex does not require explicit initialization
@@ -61,15 +66,39 @@ class MilvusIndex(EmbeddingIndex):
pass
async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
try:
collections = await self.client.list_collections()
if self.collection_name in collections:
await self.client.drop_collection(collection_name=self.collection_name)
except Exception as e:
logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
try:
collections = await self.client.list_collections()
collection_exists = self.collection_name in collections
except Exception as e:
logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
# If it's an event loop issue, try to recreate the client
if "attached to a different loop" in str(e):
logger.warning("Recreating client due to event loop issue")
if hasattr(self, "_parent_adapter"):
await self._parent_adapter._recreate_client()
collections = await self.client.list_collections()
collection_exists = self.collection_name in collections
else:
# Assume collection doesn't exist if we can't check
collection_exists = False
else:
# Assume collection doesn't exist if we can't check due to other issues
collection_exists = False
if not collection_exists:
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
@@ -123,13 +152,16 @@ class MilvusIndex(EmbeddingIndex):
)
schema.add_function(bm25_function)
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
try:
await self.client.create_collection(
self.collection_name,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
except Exception as e:
logger.error(f"Failed to create collection {self.collection_name}: {e}")
raise e
data = []
for chunk, embedding in zip(chunks, embeddings, strict=False):
@@ -143,8 +175,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(
self.client.insert,
await self.client.insert(
self.collection_name,
data=data,
)
@@ -153,8 +184,7 @@ class MilvusIndex(EmbeddingIndex):
raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
@@ -177,8 +207,7 @@ class MilvusIndex(EmbeddingIndex):
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
@@ -219,8 +248,7 @@ class MilvusIndex(EmbeddingIndex):
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
search_res = await self.client.query(
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
@@ -267,8 +295,7 @@ class MilvusIndex(EmbeddingIndex):
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor)
search_res = await asyncio.to_thread(
self.client.hybrid_search,
search_res = await self.client.hybrid_search(
collection_name=self.collection_name,
reqs=search_requests,
ranker=rerank,
@@ -294,9 +321,7 @@ class MilvusIndex(EmbeddingIndex):
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
except Exception as e:
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
@@ -321,6 +346,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = AsyncMilvusClient(uri=uri)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
@@ -334,23 +368,38 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
collection_name=vector_db.identifier,
consistency_level=self.config.consistency_level,
kvstore=self.kvstore,
parent_adapter=self,
),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
self.client.close()
if self.client:
await self.client.close()
async def _recreate_client(self) -> None:
    """Recreate the AsyncMilvusClient when event loop issues occur.

    Best-effort closes the current client, builds a fresh one from the
    adapter config (remote server or Milvus Lite), and re-points every
    cached index at the new client so stale references are not used.
    """
    try:
        # Best-effort close; a failure here must not block reconnection.
        if self.client:
            await self.client.close()
    except Exception as e:
        logger.warning(f"Error closing old client: {e}")
    # Rebuild the client exactly as initialize() does: remote config uses
    # its full dumped kwargs, otherwise fall back to a local Lite db file.
    if isinstance(self.config, RemoteMilvusVectorIOConfig):
        logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
        self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
    else:
        logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
        # expanduser so "~/..." db paths resolve to the user's home dir
        uri = os.path.expanduser(self.config.db_path)
        self.client = AsyncMilvusClient(uri=uri)
    # Swap the new client into every cached VectorDBWithIndex wrapper;
    # hasattr guards tolerate cache entries without a .index.client.
    for index_wrapper in self.cache.values():
        if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
            index_wrapper.index.client = self.client
async def register_vector_db(
self,
@@ -362,7 +411,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
consistency_level = "Strong"
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
index=MilvusIndex(
client=self.client,
collection_name=vector_db.identifier,
consistency_level=consistency_level,
parent_adapter=self,
),
inference_api=self.inference_api,
)
@@ -381,7 +435,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
index=MilvusIndex(
client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self
),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index