Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00
feat: migrate from MilvusClient to AsyncMilvusClient
This PR adds static type coverage to `llama-stack`. Part of https://github.com/meta-llama/llama-stack/issues/2647.

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent 81109a0f72
commit f75ec332b5
2 changed files with 103 additions and 107 deletions
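The core of the change is mechanical: every call that previously went through the synchronous `MilvusClient` via `asyncio.to_thread(...)` becomes a direct `await` on `AsyncMilvusClient`. A minimal before/after sketch of the pattern (illustrative only; it mirrors the calls that appear in the diff below, with `collection` as a placeholder name):

```python
import asyncio

from pymilvus import AsyncMilvusClient, MilvusClient


async def drop_before(client: MilvusClient, collection: str) -> None:
    # Before: the sync client blocks, so each call is offloaded to a thread.
    if await asyncio.to_thread(client.has_collection, collection):
        await asyncio.to_thread(client.drop_collection, collection_name=collection)


async def drop_after(client: AsyncMilvusClient, collection: str) -> None:
    # After: the async client is awaited directly; no thread hop needed.
    if await client.collection_exists(collection):
        await client.drop_collection(collection_name=collection)
```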
Changed file 1 of 2: Milvus vector IO provider module.

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import asyncio
 import json
 import logging
 import os
@@ -12,7 +11,7 @@ import re
 from typing import Any
 
 from numpy.typing import NDArray
-from pymilvus import DataType, MilvusClient
+from pymilvus import AsyncMilvusClient, DataType
 
 from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
@@ -54,7 +53,11 @@ def sanitize_collection_name(name: str) -> str:
 
 class MilvusIndex(EmbeddingIndex):
     def __init__(
-        self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
+        self,
+        client: AsyncMilvusClient,
+        collection_name: str,
+        consistency_level="Strong",
+        kvstore: KVStore | None = None,
     ):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name)
@@ -62,16 +65,15 @@ class MilvusIndex(EmbeddingIndex):
         self.kvstore = kvstore
 
     async def delete(self):
-        if await asyncio.to_thread(self.client.has_collection, self.collection_name):
-            await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
+        if await self.client.collection_exists(self.collection_name):
+            await self.client.drop_collection(collection_name=self.collection_name)
 
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
-        if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
-            await asyncio.to_thread(
-                self.client.create_collection,
+        if not await self.client.collection_exists(self.collection_name):
+            await self.client.create_collection(
                 self.collection_name,
                 dimension=len(embeddings[0]),
                 auto_id=True,
@@ -88,8 +90,7 @@ class MilvusIndex(EmbeddingIndex):
                 }
             )
         try:
-            await asyncio.to_thread(
-                self.client.insert,
+            await self.client.insert(
                 self.collection_name,
                 data=data,
             )
@@ -98,8 +99,7 @@ class MilvusIndex(EmbeddingIndex):
             raise e
 
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        search_res = await asyncio.to_thread(
-            self.client.search,
+        search_res = await self.client.search(
             collection_name=self.collection_name,
             data=[embedding],
             limit=k,
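With these changes, `MilvusIndex` owns an `AsyncMilvusClient` and its methods are awaitable end to end. A hedged construction sketch (the import path is an assumption based on the module this diff touches; the URI is a throwaway Milvus Lite file):

```python
from pymilvus import AsyncMilvusClient

# Assumed import path, for illustration only:
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex

client = AsyncMilvusClient(uri="./demo_milvus.db")  # Milvus Lite local file
index = MilvusIndex(client, "my_documents", consistency_level="Strong")
# Per the diff above, the collection is created lazily on the first add_chunks() call.
```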
@@ -139,16 +139,28 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     ) -> None:
         self.config = config
         self.cache = {}
-        self.client = None
+        self.client: AsyncMilvusClient | None = None
         self.inference_api = inference_api
         self.files_api = files_api
         self.kvstore: KVStore | None = None
-        self.vector_db_store = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"
 
     async def initialize(self) -> None:
-        self.kvstore = await kvstore_impl(self.config.kvstore)
-        start_key = VECTOR_DBS_PREFIX
-        end_key = f"{VECTOR_DBS_PREFIX}\xff"
-        stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
+        if self.config.kvstore is not None:
+            self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        # Initialize client first before using it
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Connecting to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
+        # Now load stored vector databases
+        if self.kvstore is not None:
+            start_key = VECTOR_DBS_PREFIX
+            end_key = f"{VECTOR_DBS_PREFIX}\xff"
+            stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
@@ -166,30 +178,29 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                     inference_api=self.inference_api,
                 )
                 self.cache[vector_db.identifier] = index
-        if isinstance(self.config, RemoteMilvusVectorIOConfig):
-            logger.info(f"Connecting to Milvus server at {self.config.uri}")
-            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
-        else:
-            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
-            uri = os.path.expanduser(self.config.db_path)
-            self.client = MilvusClient(uri=uri)
 
         self.openai_vector_stores = await self._load_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        self.client.close()
+        if self.client is not None:
+            await self.client.close()
 
-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    def _ensure_client_initialized(self) -> AsyncMilvusClient:
+        """Ensure the client is initialized and return it."""
+        if self.client is None:
+            raise RuntimeError("Milvus client is not initialized. Call initialize() first.")
+        return self.client
+
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
             consistency_level = self.config.consistency_level
         else:
             consistency_level = "Strong"
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
+            index=MilvusIndex(
+                self._ensure_client_initialized(), vector_db.identifier, consistency_level=consistency_level
+            ),
             inference_api=self.inference_api,
         )
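Because `self.client` is now typed `AsyncMilvusClient | None`, the adapter routes every use through `_ensure_client_initialized()`, which narrows the optional once and fails loudly if `initialize()` was never called. A standalone sketch of the same guard pattern (names are illustrative, not the adapter's):

```python
from pymilvus import AsyncMilvusClient


class Example:
    def __init__(self) -> None:
        # Set by an async initialize(); stays None until then.
        self.client: AsyncMilvusClient | None = None

    def _ensure_client(self) -> AsyncMilvusClient:
        # Narrow AsyncMilvusClient | None to AsyncMilvusClient in one place.
        if self.client is None:
            raise RuntimeError("client is not initialized; call initialize() first")
        return self.client

    async def drop(self, collection: str) -> None:
        client = self._ensure_client()  # type checkers see a non-None client here
        if await client.collection_exists(collection):
            await client.drop_collection(collection_name=collection)
```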
@@ -199,17 +210,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]
 
-        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
-        if not vector_db:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
-
-        index = VectorDBWithIndex(
-            vector_db=vector_db,
-            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
-            inference_api=self.inference_api,
-        )
-        self.cache[vector_db_id] = index
-        return index
+        # Vector DB should be registered before use
+        raise ValueError(f"Vector DB {vector_db_id} not found. Please register it first.")
 
     async def unregister_vector_db(self, vector_db_id: str) -> None:
         if vector_db_id in self.cache:
@@ -275,14 +277,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     ) -> None:
         """Save vector store file metadata to Milvus database."""
         if store_id not in self.openai_vector_stores:
-            store_info = await self._load_openai_vector_stores(store_id)
-            if not store_info:
+            # Reload all vector stores to check if the store exists
+            self.openai_vector_stores = await self._load_openai_vector_stores()
+            if store_id not in self.openai_vector_stores:
                 logger.error(f"OpenAI vector store {store_id} not found")
                 raise ValueError(f"No vector store found with id {store_id}")
 
         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
-                file_schema = MilvusClient.create_schema(
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files"):
+                file_schema = AsyncMilvusClient.create_schema(
                     auto_id=False,
                     enable_dynamic_field=True,
                     description="Metadata for OpenAI vector store files",
@@ -294,14 +299,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                 file_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
                 file_schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535)
 
-                await asyncio.to_thread(
-                    self.client.create_collection,
+                await client.create_collection(
                     collection_name="openai_vector_store_files",
                     schema=file_schema,
                 )
 
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
-                content_schema = MilvusClient.create_schema(
+            if not await client.collection_exists("openai_vector_store_files_contents"):
+                content_schema = AsyncMilvusClient.create_schema(
                     auto_id=False,
                     enable_dynamic_field=True,
                     description="Contents for OpenAI vector store files",
@@ -314,8 +318,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                 content_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
                 content_schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=65535)
 
-                await asyncio.to_thread(
-                    self.client.create_collection,
+                await client.create_collection(
                     collection_name="openai_vector_store_files_contents",
                     schema=content_schema,
                 )
@@ -328,44 +331,44 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                     "file_info": json.dumps(file_info),
                 }
             ]
-            await asyncio.to_thread(
-                self.client.upsert,
+            await client.upsert(
                 collection_name="openai_vector_store_files",
                 data=file_data,
             )
 
             # Save file contents
-            contents_data = [
-                {
-                    "chunk_id": content.get("chunk_metadata").get("chunk_id"),
-                    "store_file_id": f"{store_id}_{file_id}",
-                    "store_id": store_id,
-                    "file_id": file_id,
-                    "content": json.dumps(content),
-                }
-                for content in file_contents
-            ]
-            await asyncio.to_thread(
-                self.client.upsert,
-                collection_name="openai_vector_store_files_contents",
-                data=contents_data,
-            )
+            contents_data = []
+            for content in file_contents:
+                chunk_metadata = content.get("chunk_metadata")
+                if chunk_metadata and chunk_metadata.get("chunk_id"):
+                    contents_data.append(
+                        {
+                            "chunk_id": chunk_metadata.get("chunk_id"),
+                            "store_file_id": f"{store_id}_{file_id}",
+                            "store_id": store_id,
+                            "file_id": file_id,
+                            "content": json.dumps(content),
+                        }
+                    )
+
+            if contents_data:
+                await client.upsert(collection_name="openai_vector_store_files_contents", data=contents_data)
 
         except Exception as e:
             logger.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
+            raise
 
     async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
         """Load vector store file metadata from Milvus database."""
         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files"):
                 return {}
 
             query_filter = f"store_file_id == '{store_id}_{file_id}'"
-            results = await asyncio.to_thread(
-                self.client.query,
-                collection_name="openai_vector_store_files",
-                filter=query_filter,
-                output_fields=["file_info"],
+            results = await client.query(
+                collection_name="openai_vector_store_files", filter=query_filter, output_fields=["file_info"]
             )
 
             if results:
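The save path above also changes shape: the list comprehension that assumed every entry had `chunk_metadata.chunk_id` becomes an explicit loop that skips malformed entries, and the upsert only runs when there is something to write. The filtering logic in isolation (a sketch; `file_contents` entries are plain dicts, as in the diff):

```python
import json
from typing import Any


def build_contents_data(store_id: str, file_id: str, file_contents: list[dict[str, Any]]) -> list[dict[str, Any]]:
    contents_data: list[dict[str, Any]] = []
    for content in file_contents:
        chunk_metadata = content.get("chunk_metadata")
        # Skip entries without chunk metadata or a chunk_id instead of raising AttributeError.
        if chunk_metadata and chunk_metadata.get("chunk_id"):
            contents_data.append(
                {
                    "chunk_id": chunk_metadata["chunk_id"],
                    "store_file_id": f"{store_id}_{file_id}",
                    "store_id": store_id,
                    "file_id": file_id,
                    "content": json.dumps(content),
                }
            )
    return contents_data
```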
@@ -382,7 +385,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
         """Update vector store file metadata in Milvus database."""
         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files"):
                 return
 
             file_data = [
@@ -393,11 +398,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                     "file_info": json.dumps(file_info),
                 }
             ]
-            await asyncio.to_thread(
-                self.client.upsert,
-                collection_name="openai_vector_store_files",
-                data=file_data,
-            )
+            await client.upsert(collection_name="openai_vector_store_files", data=file_data)
         except Exception as e:
             logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
             raise
@@ -405,14 +406,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
         """Load vector store file contents from Milvus database."""
         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files_contents"):
                 return []
 
             query_filter = (
                 f"store_id == '{store_id}' AND file_id == '{file_id}' AND store_file_id == '{store_id}_{file_id}'"
             )
-            results = await asyncio.to_thread(
-                self.client.query,
+            results = await client.query(
                 collection_name="openai_vector_store_files_contents",
                 filter=query_filter,
                 output_fields=["chunk_id", "store_id", "file_id", "content"],
@@ -433,21 +435,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
         """Delete vector store file metadata from Milvus database."""
         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files"):
                 return
 
             query_filter = f"store_file_id in ['{store_id}_{file_id}']"
-            await asyncio.to_thread(
-                self.client.delete,
-                collection_name="openai_vector_store_files",
-                filter=query_filter,
-            )
-            if await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
-                await asyncio.to_thread(
-                    self.client.delete,
-                    collection_name="openai_vector_store_files_contents",
-                    filter=query_filter,
-                )
+            await client.delete(collection_name="openai_vector_store_files", filter=query_filter)
+            if await client.collection_exists("openai_vector_store_files_contents"):
+                await client.delete(collection_name="openai_vector_store_files_contents", filter=query_filter)
 
         except Exception as e:
             logger.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
Changed file 2 of 2: Milvus vector IO unit tests.

@@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
 import numpy as np
 import pytest
 import pytest_asyncio
-from pymilvus import Collection, MilvusClient, connections
+from pymilvus import AsyncMilvusClient, Collection, connections
 
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
@@ -52,7 +52,7 @@ async def unique_kvstore_config(tmp_path_factory):
 async def milvus_vec_index(embedding_dimension, tmp_path_factory):
     temp_dir = tmp_path_factory.getbasetemp()
     db_path = str(temp_dir / "test_milvus.db")
-    client = MilvusClient(db_path)
+    client = AsyncMilvusClient(uri=db_path)
     name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
     connections.connect(alias=MILVUS_ALIAS, uri=db_path)
     index = MilvusIndex(client, name, consistency_level="Strong")
@@ -148,7 +148,7 @@ async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_co
     await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
 
     assert milvus_vec_index.client is not None
-    assert isinstance(milvus_vec_index.client, MilvusClient)
+    assert isinstance(milvus_vec_index.client, AsyncMilvusClient)
     assert tmp_milvus_vec_adapter.cache is not None
     # registering a vector won't update the cache or openai_vector_store collection name
     assert (
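For completeness, a hedged sketch of exercising the migrated client style under pytest-asyncio (collection and field names are illustrative; the calls mirror those used in this diff, including passing a Milvus Lite file path as the URI, as the fixture above does):

```python
import numpy as np
import pytest
from pymilvus import AsyncMilvusClient


@pytest.mark.asyncio
async def test_async_client_roundtrip(tmp_path):
    # Milvus Lite: a local file path serves as the connection URI.
    client = AsyncMilvusClient(uri=str(tmp_path / "test_milvus.db"))
    await client.create_collection("demo", dimension=4, auto_id=True)
    await client.insert("demo", data=[{"vector": np.random.rand(4).tolist()}])
    assert await client.collection_exists("demo")
```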