Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)

fix(integration): init AsyncMilvusClient before MilvusIndex

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

parent: e7444c1d9b
commit: 733d0c70fe

2 changed files with 77 additions and 22 deletions
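In short: MilvusVectorIOAdapter.initialize() now creates the AsyncMilvusClient before rebuilding MilvusIndex objects from the kvstore, MilvusIndex accepts an optional parent_adapter reference so it can call the adapter's new _recreate_client() when an "attached to a different loop" error occurs, and the collection create/delete paths gain error handling. The test-fixture hunk at the end only trims redundant comments.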
```diff
@@ -52,11 +52,13 @@ class MilvusIndex(EmbeddingIndex):
         collection_name: str,
         consistency_level="Strong",
         kvstore: KVStore | None = None,
+        parent_adapter=None,
     ):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name)
         self.consistency_level = consistency_level
         self.kvstore = kvstore
+        self._parent_adapter = parent_adapter
 
     async def initialize(self):
        # MilvusIndex does not require explicit initialization
```
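For orientation, a minimal sketch of how the widened constructor is meant to be called, using only the parameters visible in this hunk; the client, kvstore, and adapter objects are hypothetical placeholders.

```python
# Sketch only: parameter names come from the hunk above; `async_client`,
# `my_kvstore`, and `adapter` are hypothetical placeholders.
index = MilvusIndex(
    client=async_client,             # AsyncMilvusClient owned by the adapter
    collection_name="my-vector-db",  # sanitized via sanitize_collection_name()
    consistency_level="Strong",
    kvstore=my_kvstore,              # optional; may be None
    parent_adapter=adapter,          # new: lets the index ask the adapter to rebuild its client
)
```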
```diff
@@ -64,15 +66,36 @@ class MilvusIndex(EmbeddingIndex):
         pass
 
     async def delete(self):
-        if await self.client.has_collection(self.collection_name):
-            await self.client.drop_collection(collection_name=self.collection_name)
+        try:
+            if await self.client.has_collection(self.collection_name):
+                await self.client.drop_collection(collection_name=self.collection_name)
+        except Exception as e:
+            logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")
 
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
 
-        if not await self.client.has_collection(self.collection_name):
+        try:
+            collection_exists = await self.client.has_collection(self.collection_name)
+        except Exception as e:
+            logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
+            # If it's an event loop issue, try to recreate the client
+            if "attached to a different loop" in str(e):
+                logger.warning("Recreating client due to event loop issue")
+
+                if hasattr(self, "_parent_adapter"):
+                    await self._parent_adapter._recreate_client()
+                    collection_exists = await self.client.has_collection(self.collection_name)
+                else:
+                    # Assume collection doesn't exist if we can't check
+                    collection_exists = False
+            else:
+                # Assume collection doesn't exist if we can't check due to other issues
+                collection_exists = False
+
+        if not collection_exists:
             logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
             # Create schema for vector search
             schema = self.client.create_schema()
```
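The new add_chunks() logic boils down to a small recovery pattern: probe the collection, and if the probe fails because the client is bound to a different event loop, ask the parent adapter to rebuild the client and probe once more. A standalone sketch of that pattern, where `check` and `recreate` are hypothetical stand-ins for `client.has_collection()` and the adapter's `_recreate_client()`:

```python
# Minimal sketch of the recovery pattern, not the adapter's actual code.
async def exists_with_loop_recovery(check, recreate) -> bool:
    try:
        return await check()
    except Exception as e:
        if "attached to a different loop" in str(e):
            # The client was created on another (possibly closed) event loop:
            # rebuild it and retry the probe once.
            await recreate()
            return await check()
        # Otherwise assume the collection does not exist so the caller recreates it.
        return False
```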
```diff
@@ -126,12 +149,16 @@ class MilvusIndex(EmbeddingIndex):
             )
             schema.add_function(bm25_function)
 
-            await self.client.create_collection(
-                self.collection_name,
-                schema=schema,
-                index_params=index_params,
-                consistency_level=self.consistency_level,
-            )
+            try:
+                await self.client.create_collection(
+                    self.collection_name,
+                    schema=schema,
+                    index_params=index_params,
+                    consistency_level=self.consistency_level,
+                )
+            except Exception as e:
+                logger.error(f"Failed to create collection {self.collection_name}: {e}")
+                raise e
 
         data = []
         for chunk, embedding in zip(chunks, embeddings, strict=False):
```
```diff
@@ -316,6 +343,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
 
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Connecting to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
         start_key = VECTOR_DBS_PREFIX
         end_key = f"{VECTOR_DBS_PREFIX}\xff"
         stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
```
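This is the core of the fix named in the commit title: `initialize()` now builds the AsyncMilvusClient before the stored vector DBs are read back from the kvstore and wrapped in MilvusIndex objects (the next hunk removes the old, later client creation). A condensed ordering sketch, using only names that appear in the diff:

```python
# Condensed ordering sketch (not the full method): the client must exist
# before any MilvusIndex is rebuilt from persisted vector DB records.
async def initialize(self) -> None:
    self.kvstore = await kvstore_impl(self.config.kvstore)

    # 1. Create the client first: remote Milvus server or local Milvus Lite file.
    if isinstance(self.config, RemoteMilvusVectorIOConfig):
        self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
    else:
        self.client = AsyncMilvusClient(uri=os.path.expanduser(self.config.db_path))

    # 2. Only then rebuild indexes for previously registered vector DBs,
    #    handing each MilvusIndex the live client and parent_adapter=self.
    stored_vector_dbs = await self.kvstore.values_in_range(VECTOR_DBS_PREFIX, f"{VECTOR_DBS_PREFIX}\xff")
    ...
```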
```diff
@@ -329,24 +365,39 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
                     collection_name=vector_db.identifier,
                     consistency_level=self.config.consistency_level,
                     kvstore=self.kvstore,
+                    parent_adapter=self,
                 ),
                 inference_api=self.inference_api,
             )
             self.cache[vector_db.identifier] = index
-        if isinstance(self.config, RemoteMilvusVectorIOConfig):
-            logger.info(f"Connecting to Milvus server at {self.config.uri}")
-            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
-        else:
-            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
-            uri = os.path.expanduser(self.config.db_path)
-            self.client = AsyncMilvusClient(uri=uri)
 
         # Load existing OpenAI vector stores into the in-memory cache
         await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        await self.client.close()
+        if self.client:
+            await self.client.close()
+
+    async def _recreate_client(self) -> None:
+        """Recreate the AsyncMilvusClient when event loop issues occur"""
+        try:
+            if self.client:
+                await self.client.close()
+        except Exception as e:
+            logger.warning(f"Error closing old client: {e}")
+
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
+        for index_wrapper in self.cache.values():
+            if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
+                index_wrapper.index.client = self.client
 
     async def register_vector_db(
         self,
         vector_db: VectorDB,
```
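_recreate_client() closes the stale client, reconnects, and then re-points every cached index at the new client so existing VectorDBWithIndex wrappers keep working. A hypothetical helper showing when an adapter might call it; only `_recreate_client()` and `client.has_collection()` are taken from the diff, the rest is illustrative.

```python
# Hypothetical usage sketch, not part of the diff: detect a client bound to a
# dead event loop and let the adapter rebuild it (which also updates the
# cached MilvusIndex instances).
async def ensure_live_client(adapter) -> None:
    try:
        await adapter.client.has_collection("healthcheck")  # any cheap call works
    except Exception as e:
        if "attached to a different loop" in str(e):
            await adapter._recreate_client()
        else:
            raise
```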
```diff
@@ -357,7 +408,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             consistency_level = "Strong"
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
+            index=MilvusIndex(
+                client=self.client,
+                collection_name=vector_db.identifier,
+                consistency_level=consistency_level,
+                parent_adapter=self,
+            ),
             inference_api=self.inference_api,
         )
 
```
```diff
@@ -376,7 +432,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
 
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
+            index=MilvusIndex(
+                client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self
+            ),
             inference_api=self.inference_api,
         )
         self.cache[vector_db_id] = index
```
```diff
@@ -40,15 +40,12 @@ async def mock_milvus_client() -> MagicMock:
     """Create a mock Milvus client with common method behaviors."""
     client = MagicMock()
 
-    # Mock async collection operations
     client.has_collection = AsyncMock(return_value=False)  # Initially no collection
     client.create_collection = AsyncMock(return_value=None)
     client.drop_collection = AsyncMock(return_value=None)
 
-    # Mock async insert operation
     client.insert = AsyncMock(return_value={"insert_count": 10})
 
-    # Mock async search operation
     client.search = AsyncMock(
         return_value=[
             [
```
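With has_collection, create_collection, drop_collection, insert, and search all mocked as AsyncMock, tests can await MilvusIndex methods directly against the fixture. A sketch of one such test, assuming pytest-asyncio; only the fixture and MilvusIndex come from the diff, the test itself is an example.

```python
# Illustrative test sketch (assumes pytest-asyncio is configured).
import pytest


@pytest.mark.asyncio
async def test_delete_skips_missing_collection(mock_milvus_client):
    index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
    await index.delete()

    # has_collection is mocked to return False, so nothing should be dropped.
    mock_milvus_client.has_collection.assert_awaited_once()
    mock_milvus_client.drop_collection.assert_not_awaited()
```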