mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Merge 4e566276a5
into 09abdb0a37
This commit is contained in:
commit
a2988cac4b
4 changed files with 480 additions and 35 deletions
|
@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
|
||||||
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
|
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
|
||||||
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
|
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
|
||||||
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
|
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
|
||||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig \| None` | No | None | Config for KV store backend (SQLite only for now). Optional for remote Milvus connections - only needed for vector database registry persistence across server restarts. |
|
||||||
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
|
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
|
||||||
|
|
||||||
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
|
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
|
||||||
|
|
|
@ -17,7 +17,10 @@ class MilvusVectorIOConfig(BaseModel):
|
||||||
uri: str = Field(description="The URI of the Milvus server")
|
uri: str = Field(description="The URI of the Milvus server")
|
||||||
token: str | None = Field(description="The token of the Milvus server")
|
token: str | None = Field(description="The token of the Milvus server")
|
||||||
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
|
||||||
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
|
kvstore: KVStoreConfig | None = Field(
|
||||||
|
description="Config for KV store backend (SQLite only for now). Optional for remote Milvus connections - only needed for vector database registry persistence across server restarts.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
# This configuration allows additional fields to be passed through to the underlying Milvus client.
|
||||||
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
|
||||||
|
|
|
@ -5,13 +5,20 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
from pymilvus import DataType, MilvusClient
|
||||||
|
# Function and FunctionType are not available in all pymilvus versions
|
||||||
|
try:
|
||||||
|
from pymilvus import Function, FunctionType
|
||||||
|
except ImportError:
|
||||||
|
Function = None
|
||||||
|
FunctionType = None
|
||||||
|
|
||||||
from llama_stack.apis.files.files import Files
|
from llama_stack.apis.files.files import Files
|
||||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||||
|
@ -276,10 +283,32 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
# MilvusVectorIOAdapter is used for both inline and remote connections
|
||||||
start_key = VECTOR_DBS_PREFIX
|
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
# Remote Milvus: kvstore is optional for registry persistence across server restarts
|
||||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
if self.config.kvstore is not None:
|
||||||
|
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||||
|
logger.info("Remote Milvus: Using kvstore for vector database registry persistence")
|
||||||
|
else:
|
||||||
|
self.kvstore = None
|
||||||
|
logger.info("Remote Milvus: No kvstore configured, registry will not persist across restarts")
|
||||||
|
if self.kvstore is not None:
|
||||||
|
start_key = VECTOR_DBS_PREFIX
|
||||||
|
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||||
|
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
|
else:
|
||||||
|
stored_vector_dbs = []
|
||||||
|
|
||||||
|
elif isinstance(self.config, InlineMilvusVectorIOConfig):
|
||||||
|
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||||
|
logger.info("Inline Milvus: Using kvstore for local vector database registry")
|
||||||
|
start_key = VECTOR_DBS_PREFIX
|
||||||
|
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||||
|
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported config type: {type(self.config)}. Expected RemoteMilvusVectorIOConfig or InlineMilvusVectorIOConfig"
|
||||||
|
)
|
||||||
|
|
||||||
for vector_db_data in stored_vector_dbs:
|
for vector_db_data in stored_vector_dbs:
|
||||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||||
|
@ -295,12 +324,16 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
)
|
)
|
||||||
self.cache[vector_db.identifier] = index
|
self.cache[vector_db.identifier] = index
|
||||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
logger.info(f"Connecting to remote Milvus server at {self.config.uri}")
|
||||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||||
else:
|
elif isinstance(self.config, InlineMilvusVectorIOConfig):
|
||||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
logger.info(f"Connecting to local Milvus Lite at: {self.config.db_path}")
|
||||||
uri = os.path.expanduser(self.config.db_path)
|
uri = os.path.expanduser(self.config.db_path)
|
||||||
self.client = MilvusClient(uri=uri)
|
self.client = MilvusClient(uri=uri)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported config type: {type(self.config)}. Expected RemoteMilvusVectorIOConfig or InlineMilvusVectorIOConfig"
|
||||||
|
)
|
||||||
|
|
||||||
# Load existing OpenAI vector stores into the in-memory cache
|
# Load existing OpenAI vector stores into the in-memory cache
|
||||||
await self.initialize_openai_vector_stores()
|
await self.initialize_openai_vector_stores()
|
||||||
|
@ -314,8 +347,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||||
consistency_level = self.config.consistency_level
|
consistency_level = self.config.consistency_level
|
||||||
|
elif isinstance(self.config, InlineMilvusVectorIOConfig):
|
||||||
|
consistency_level = self.config.consistency_level
|
||||||
else:
|
else:
|
||||||
consistency_level = "Strong"
|
raise ValueError(
|
||||||
|
f"Unsupported config type: {type(self.config)}. Expected RemoteMilvusVectorIOConfig or InlineMilvusVectorIOConfig"
|
||||||
|
)
|
||||||
index = VectorDBWithIndex(
|
index = VectorDBWithIndex(
|
||||||
vector_db=vector_db,
|
vector_db=vector_db,
|
||||||
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
|
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
|
||||||
|
@ -389,3 +426,217 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
for chunk_id in chunk_ids:
|
for chunk_id in chunk_ids:
|
||||||
# Use the index's delete_chunk method
|
# Use the index's delete_chunk method
|
||||||
await index.index.delete_chunk(chunk_id)
|
await index.index.delete_chunk(chunk_id)
|
||||||
|
|
||||||
|
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
    """Record the metadata of an OpenAI vector store.

    When a kvstore is configured, the metadata is written to it as JSON under
    ``OPENAI_VECTOR_STORES_PREFIX + store_id``. The in-memory
    ``openai_vector_stores`` mapping is updated in either case.
    """
    backend = self.kvstore
    if backend is not None:
        # Without a kvstore the entry lives only in the in-memory mapping.
        storage_key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
        payload = json.dumps(store_info)
        await backend.set(key=storage_key, value=payload)
    self.openai_vector_stores[store_id] = store_info
|
||||||
|
|
||||||
|
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
    """Overwrite the metadata of an OpenAI vector store.

    The kvstore write (JSON under ``OPENAI_VECTOR_STORES_PREFIX + store_id``)
    happens only when ``self.kvstore`` is not ``None``; the in-memory
    ``openai_vector_stores`` entry is replaced unconditionally.
    """
    persistent = self.kvstore
    if persistent is not None:
        record_key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
        encoded = json.dumps(store_info)
        await persistent.set(key=record_key, value=encoded)
    self.openai_vector_stores[store_id] = store_info
|
||||||
|
|
||||||
|
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
    """Remove an OpenAI vector store's metadata.

    Deletes the ``OPENAI_VECTOR_STORES_PREFIX + store_id`` key from the kvstore
    when one is configured, then drops the store from the in-memory
    ``openai_vector_stores`` mapping if it is present there.
    """
    backend = self.kvstore
    if backend is not None:
        await backend.delete(f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}")
    # pop() with a default is a no-op when the store was never cached.
    self.openai_vector_stores.pop(store_id, None)
|
||||||
|
|
||||||
|
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
    """Load all vector store metadata from persistent storage.

    Returns:
        A mapping from each stored record's ``"id"`` field to the decoded
        record. Empty when no kvstore is configured.
    """
    if self.kvstore is None:
        return {}
    start_key = OPENAI_VECTOR_STORES_PREFIX
    end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
    stored = await self.kvstore.values_in_range(start_key, end_key)
    # Decode each record once and reuse it for both the key and the value.
    records = (json.loads(raw) for raw in stored)
    return {record["id"]: record for record in records}
|
||||||
|
|
||||||
|
async def _save_openai_vector_store_file(
    self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
    """Save vector store file metadata and contents to the Milvus database.

    Args:
        store_id: Identifier of the vector store the file belongs to.
        file_id: Identifier of the file.
        file_info: File metadata; stored JSON-encoded in the ``file_info`` field.
        file_contents: Chunk dicts; each is stored JSON-encoded and keyed by
            its ``chunk_metadata.chunk_id``.

    Raises:
        ValueError: If ``store_id`` is neither cached in memory nor present
            in persistent storage.
    """
    if store_id not in self.openai_vector_stores:
        # _load_openai_vector_stores() takes no arguments and returns every
        # persisted store keyed by id, so look the store up in that mapping.
        persisted_stores = await self._load_openai_vector_stores()
        if store_id not in persisted_stores:
            logger.error(f"OpenAI vector store {store_id} not found")
            raise ValueError(f"No vector store found with id {store_id}")

    try:
        # Lazily create the collection holding one metadata row per file.
        if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
            file_schema = MilvusClient.create_schema(
                auto_id=False,
                enable_dynamic_field=True,
                description="Metadata for OpenAI vector store files",
            )
            file_schema.add_field(
                field_name="store_file_id", datatype=DataType.VARCHAR, is_primary=True, max_length=512
            )
            file_schema.add_field(field_name="store_id", datatype=DataType.VARCHAR, max_length=512)
            file_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
            file_schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535)

            await asyncio.to_thread(
                self.client.create_collection,
                collection_name="openai_vector_store_files",
                schema=file_schema,
            )

        # Lazily create the collection holding one row per file chunk.
        if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
            content_schema = MilvusClient.create_schema(
                auto_id=False,
                enable_dynamic_field=True,
                description="Contents for OpenAI vector store files",
            )
            content_schema.add_field(
                field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=1024
            )
            content_schema.add_field(field_name="store_file_id", datatype=DataType.VARCHAR, max_length=1024)
            content_schema.add_field(field_name="store_id", datatype=DataType.VARCHAR, max_length=512)
            content_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
            content_schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=65535)

            await asyncio.to_thread(
                self.client.create_collection,
                collection_name="openai_vector_store_files_contents",
                schema=content_schema,
            )

        # The composite "<store_id>_<file_id>" value is the primary key of the file row.
        file_data = [
            {
                "store_file_id": f"{store_id}_{file_id}",
                "store_id": store_id,
                "file_id": file_id,
                "file_info": json.dumps(file_info),
            }
        ]
        await asyncio.to_thread(
            self.client.upsert,
            collection_name="openai_vector_store_files",
            data=file_data,
        )

        # Save file contents
        contents_data = [
            {
                "chunk_id": content.get("chunk_metadata").get("chunk_id"),
                "store_file_id": f"{store_id}_{file_id}",
                "store_id": store_id,
                "file_id": file_id,
                "content": json.dumps(content),
            }
            for content in file_contents
        ]
        await asyncio.to_thread(
            self.client.upsert,
            collection_name="openai_vector_store_files_contents",
            data=contents_data,
        )

    except Exception as e:
        # NOTE(review): failures here are logged but not re-raised, unlike the
        # update/delete counterparts. Kept as-is in case callers rely on save
        # being best-effort — confirm.
        logger.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
|
||||||
|
|
||||||
|
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
    """Fetch the stored metadata of one vector store file from Milvus.

    Returns the decoded ``file_info`` of the row whose ``store_file_id`` is
    ``"<store_id>_<file_id>"``. An empty dict is returned when the collection
    or the row does not exist, when the stored JSON cannot be decoded, or when
    any other error occurs (errors are logged, not raised).
    """
    collection = "openai_vector_store_files"
    try:
        collection_exists = await asyncio.to_thread(self.client.has_collection, collection)
        if not collection_exists:
            return {}

        rows = await asyncio.to_thread(
            self.client.query,
            collection_name=collection,
            filter=f"store_file_id == '{store_id}_{file_id}'",
            output_fields=["file_info"],
        )
        if not rows:
            return {}

        raw_info = rows[0]["file_info"]
        try:
            return json.loads(raw_info)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to decode file_info for store {store_id}, file {file_id}: {e}")
            return {}
    except Exception as e:
        logger.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
        return {}
|
||||||
|
|
||||||
|
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
    """Upsert the metadata row of one vector store file in Milvus.

    Does nothing when the ``openai_vector_store_files`` collection does not
    exist. Any error is logged and then re-raised.
    """
    collection = "openai_vector_store_files"
    try:
        collection_exists = await asyncio.to_thread(self.client.has_collection, collection)
        if not collection_exists:
            return

        # Upsert keyed on the composite "<store_id>_<file_id>" primary key.
        row = {
            "store_file_id": f"{store_id}_{file_id}",
            "store_id": store_id,
            "file_id": file_id,
            "file_info": json.dumps(file_info),
        }
        await asyncio.to_thread(
            self.client.upsert,
            collection_name=collection,
            data=[row],
        )
    except Exception as e:
        logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
        raise
|
||||||
|
|
||||||
|
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
    """Fetch the stored chunk contents of one vector store file from Milvus.

    Returns the JSON-decoded ``content`` of every matching row. Rows whose
    content cannot be decoded are logged and skipped. An empty list is
    returned when the contents collection does not exist or when any other
    error occurs (errors are logged, not raised).
    """
    collection = "openai_vector_store_files_contents"
    try:
        collection_exists = await asyncio.to_thread(self.client.has_collection, collection)
        if not collection_exists:
            return []

        rows = await asyncio.to_thread(
            self.client.query,
            collection_name=collection,
            filter=f"store_id == '{store_id}' AND file_id == '{file_id}' AND store_file_id == '{store_id}_{file_id}'",
            output_fields=["chunk_id", "store_id", "file_id", "content"],
        )

        decoded_chunks = []
        for row in rows:
            try:
                decoded_chunks.append(json.loads(row["content"]))
            except json.JSONDecodeError as e:
                logger.error(f"Failed to decode content for store {store_id}, file {file_id}: {e}")
        return decoded_chunks
    except Exception as e:
        logger.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
        return []
|
||||||
|
|
||||||
|
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
    """Remove one vector store file's metadata and contents from Milvus.

    Does nothing when the ``openai_vector_store_files`` collection does not
    exist. Otherwise the file row is deleted and, if the contents collection
    exists, the rows with the same ``store_file_id`` are deleted there too.
    Any error is logged and then re-raised.
    """
    files_collection = "openai_vector_store_files"
    contents_collection = "openai_vector_store_files_contents"
    try:
        files_exist = await asyncio.to_thread(self.client.has_collection, files_collection)
        if not files_exist:
            return

        # Both collections carry a store_file_id field, so one filter serves both deletes.
        id_filter = f"store_file_id in ['{store_id}_{file_id}']"
        await asyncio.to_thread(
            self.client.delete,
            collection_name=files_collection,
            filter=id_filter,
        )

        contents_exist = await asyncio.to_thread(self.client.has_collection, contents_collection)
        if contents_exist:
            await asyncio.to_thread(
                self.client.delete,
                collection_name=contents_collection,
                filter=id_filter,
            )
    except Exception as e:
        logger.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
        raise
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -18,7 +19,8 @@ pymilvus_mock.MilvusClient = MagicMock
|
||||||
|
|
||||||
# Apply the mock before importing MilvusIndex
|
# Apply the mock before importing MilvusIndex
|
||||||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
||||||
|
from llama_stack.providers.remote.vector_io.milvus.config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||||
|
|
||||||
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
||||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||||
|
@ -91,6 +93,40 @@ async def milvus_index(mock_milvus_client):
|
||||||
# No real cleanup needed since we're using mocks
|
# No real cleanup needed since we're using mocks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def mock_inference_api():
    """Provide a MagicMock inference API whose ``embed`` returns one fixed 3-d vector."""
    inference = MagicMock()
    fixed_embedding = np.array([[0.1, 0.2, 0.3]])
    inference.embed.return_value = fixed_embedding
    return inference
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def remote_milvus_config_with_kvstore():
    """Provide a remote Milvus config that carries a SQLite kvstore config."""
    from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

    # Use proper kvstore config
    sqlite_kvstore = SqliteKVStoreConfig(db_path="/tmp/test.db")
    return RemoteMilvusVectorIOConfig(
        uri="http://localhost:19530",
        token=None,
        consistency_level="Strong",
        kvstore=sqlite_kvstore,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def remote_milvus_config_without_kvstore():
    """Provide a remote Milvus config whose ``kvstore`` is explicitly ``None``."""
    settings = {
        "uri": "http://localhost:19530",
        "token": None,
        "consistency_level": "Strong",
        "kvstore": None,  # No kvstore
    }
    return RemoteMilvusVectorIOConfig(**settings)
|
||||||
|
|
||||||
|
|
||||||
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
# Setup: collection doesn't exist initially, then exists after creation
|
# Setup: collection doesn't exist initially, then exists after creation
|
||||||
mock_milvus_client.has_collection.side_effect = [False, True]
|
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||||
|
@ -101,8 +137,9 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_m
|
||||||
mock_milvus_client.create_collection.assert_called_once()
|
mock_milvus_client.create_collection.assert_called_once()
|
||||||
mock_milvus_client.insert.assert_called_once()
|
mock_milvus_client.insert.assert_called_once()
|
||||||
|
|
||||||
# Verify the insert call had the right number of chunks
|
# Verify the data format in the insert call
|
||||||
insert_call = mock_milvus_client.insert.call_args
|
insert_call = mock_milvus_client.insert.call_args
|
||||||
|
assert insert_call[1]["collection_name"] == "test_collection"
|
||||||
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,67 +150,71 @@ async def test_query_chunks_vector(
|
||||||
mock_milvus_client.has_collection.return_value = True
|
mock_milvus_client.has_collection.return_value = True
|
||||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
# Test vector search
|
# Query with a test embedding
|
||||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
query_embedding = np.random.rand(embedding_dimension)
|
||||||
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
# Verify search was called and response is valid
|
||||||
|
mock_milvus_client.search.assert_called_once()
|
||||||
assert isinstance(response, QueryChunksResponse)
|
assert isinstance(response, QueryChunksResponse)
|
||||||
assert len(response.chunks) == 2
|
assert len(response.chunks) == 2
|
||||||
mock_milvus_client.search.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
# Setup: Add chunks first
|
||||||
mock_milvus_client.has_collection.return_value = True
|
mock_milvus_client.has_collection.return_value = True
|
||||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
# Test keyword search
|
# Test keyword search
|
||||||
query_string = "Sentence 5"
|
query_string = "test query"
|
||||||
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
# Verify search was called and response is valid
|
||||||
|
mock_milvus_client.search.assert_called_once()
|
||||||
assert isinstance(response, QueryChunksResponse)
|
assert isinstance(response, QueryChunksResponse)
|
||||||
assert len(response.chunks) == 2
|
assert len(response.chunks) == 2
|
||||||
|
|
||||||
|
|
||||||
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
# Setup: Add chunks first
|
||||||
mock_milvus_client.has_collection.return_value = True
|
mock_milvus_client.has_collection.return_value = True
|
||||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
# Force BM25 search to fail
|
# Mock BM25 search to fail, triggering fallback
|
||||||
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||||
|
|
||||||
# Mock simple text search results
|
# Mock the fallback query to return results
|
||||||
mock_milvus_client.query.return_value = [
|
mock_milvus_client.query.return_value = [
|
||||||
{
|
{
|
||||||
"chunk_id": "chunk1",
|
"chunk_id": "chunk1",
|
||||||
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"chunk_id": "chunk2",
|
"chunk_id": "chunk2",
|
||||||
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk3",
|
||||||
|
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
# Test keyword search that should fall back to simple text search
|
# Test keyword search with fallback
|
||||||
query_string = "Python"
|
query_string = "test query"
|
||||||
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||||
|
|
||||||
# Verify response structure
|
# Verify both search and query were called (search failed, query succeeded)
|
||||||
assert isinstance(response, QueryChunksResponse)
|
|
||||||
assert len(response.chunks) > 0, "Fallback search should return results"
|
|
||||||
|
|
||||||
# Verify that simple text search was used (query method called instead of search)
|
|
||||||
mock_milvus_client.query.assert_called_once()
|
mock_milvus_client.query.assert_called_once()
|
||||||
mock_milvus_client.search.assert_called_once() # Called once but failed
|
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||||
|
|
||||||
# Verify the query uses parameterized filter with filter_params
|
# Verify the query call arguments
|
||||||
query_call_args = mock_milvus_client.query.call_args
|
query_call_args = mock_milvus_client.query.call_args
|
||||||
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
assert query_call_args[1]["collection_name"] == "test_collection"
|
||||||
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
|
assert "content like" in query_call_args[1]["filter"]
|
||||||
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
|
|
||||||
|
|
||||||
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
# Verify response is valid
|
||||||
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 3
|
||||||
|
|
||||||
|
|
||||||
async def test_delete_collection(milvus_index, mock_milvus_client):
|
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||||
|
@ -183,3 +224,153 @@ async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||||
await milvus_index.delete()
|
await milvus_index.delete()
|
||||||
|
|
||||||
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for kvstore None handling fix
|
||||||
|
async def test_remote_milvus_initialization_with_kvstore(remote_milvus_config_with_kvstore, mock_inference_api):
    """Remote Milvus builds its kvstore from the config when one is supplied."""

    def _resolved(value):
        # An already-completed future, so awaiting the mocked call yields `value`.
        fut = asyncio.Future()
        fut.set_result(value)
        return fut

    with patch("llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient") as client_cls:
        client_cls.return_value = MagicMock()

        with patch("llama_stack.providers.remote.vector_io.milvus.milvus.kvstore_impl") as kvstore_factory:
            fake_kvstore = MagicMock()
            kvstore_factory.return_value = fake_kvstore
            fake_kvstore.values_in_range.return_value = _resolved([])
            fake_kvstore.set.return_value = _resolved(None)

            adapter = MilvusVectorIOAdapter(
                config=remote_milvus_config_with_kvstore,
                inference_api=mock_inference_api,
                files_api=None,
            )
            await adapter.initialize()

            # The kvstore must have been built from the configured kvstore settings.
            kvstore_factory.assert_called_once_with(remote_milvus_config_with_kvstore.kvstore)
            assert adapter.kvstore is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_remote_milvus_initialization_without_kvstore(remote_milvus_config_without_kvstore, mock_inference_api):
    """Test that remote Milvus initializes correctly without kvstore (None).

    ``kvstore_impl`` is patched alongside ``MilvusClient`` so the test can
    actually assert that no kvstore backend is constructed, rather than only
    checking the resulting attribute.
    """
    with patch("llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client

        with patch("llama_stack.providers.remote.vector_io.milvus.milvus.kvstore_impl") as mock_kvstore_impl:
            adapter = MilvusVectorIOAdapter(
                config=remote_milvus_config_without_kvstore,
                inference_api=mock_inference_api,
                files_api=None,
            )

            await adapter.initialize()

            # Verify kvstore is None and no kvstore_impl was called
            mock_kvstore_impl.assert_not_called()
            assert adapter.kvstore is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_vector_store_methods_without_kvstore(remote_milvus_config_without_kvstore, mock_inference_api):
    """Check that saving an OpenAI vector store with kvstore set to None still fills the in-memory cache."""
    client_target = "llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient"

    with patch(client_target) as client_cls:
        client_cls.return_value = MagicMock()

        adapter = MilvusVectorIOAdapter(
            config=remote_milvus_config_without_kvstore,
            inference_api=mock_inference_api,
            files_api=None,
        )
        await adapter.initialize()

        vs_id = "test_store"
        vs_info = {"id": vs_id, "name": "test"}

        # The save call itself must complete without raising.
        await adapter._save_openai_vector_store(vs_id, vs_info)

        # The entry must be visible in the adapter's in-memory mapping afterwards.
        cached = adapter.openai_vector_stores
        assert vs_id in cached
        assert cached[vs_id] == vs_info
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_vector_store_methods_with_kvstore(remote_milvus_config_with_kvstore, mock_inference_api):
    """Check that saving an OpenAI vector store writes to both the kvstore and the in-memory cache."""
    client_target = "llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient"
    kvstore_target = "llama_stack.providers.remote.vector_io.milvus.milvus.kvstore_impl"

    with patch(client_target) as client_cls, patch(kvstore_target) as kvstore_factory:
        client_cls.return_value = MagicMock()

        # Already-resolved futures so the mocked kvstore methods hand back awaitables.
        empty_range = asyncio.Future()
        empty_range.set_result([])
        set_done = asyncio.Future()
        set_done.set_result(None)

        fake_kvstore = MagicMock()
        fake_kvstore.values_in_range.return_value = empty_range
        fake_kvstore.set.return_value = set_done
        kvstore_factory.return_value = fake_kvstore

        adapter = MilvusVectorIOAdapter(
            config=remote_milvus_config_with_kvstore,
            inference_api=mock_inference_api,
            files_api=None,
        )
        await adapter.initialize()

        vs_id = "test_store"
        vs_info = {"id": vs_id, "name": "test"}

        await adapter._save_openai_vector_store(vs_id, vs_info)

        # Exactly one kvstore write, plus the matching in-memory entry.
        fake_kvstore.set.assert_called_once()
        cached = adapter.openai_vector_stores
        assert vs_id in cached
        assert cached[vs_id] == vs_info
|
||||||
|
|
||||||
|
|
||||||
|
async def test_load_openai_vector_stores_without_kvstore(remote_milvus_config_without_kvstore, mock_inference_api):
    """Check that _load_openai_vector_stores yields an empty dict when kvstore is None."""
    client_target = "llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient"

    with patch(client_target) as client_cls:
        client_cls.return_value = MagicMock()

        adapter = MilvusVectorIOAdapter(
            config=remote_milvus_config_without_kvstore,
            inference_api=mock_inference_api,
            files_api=None,
        )
        await adapter.initialize()

        # With no kvstore to read from, the expected result is an empty dict.
        loaded = await adapter._load_openai_vector_stores()
        assert loaded == {}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_openai_vector_store_without_kvstore(remote_milvus_config_without_kvstore, mock_inference_api):
    """Check that _delete_openai_vector_store_from_storage clears the cache entry when kvstore is None."""
    client_target = "llama_stack.providers.remote.vector_io.milvus.milvus.MilvusClient"

    with patch(client_target) as client_cls:
        client_cls.return_value = MagicMock()

        adapter = MilvusVectorIOAdapter(
            config=remote_milvus_config_without_kvstore,
            inference_api=mock_inference_api,
            files_api=None,
        )
        await adapter.initialize()

        # Seed the in-memory mapping directly so there is something to delete.
        vs_id = "test_store"
        adapter.openai_vector_stores[vs_id] = {"id": vs_id}

        # The delete must complete without raising.
        await adapter._delete_openai_vector_store_from_storage(vs_id)

        # The seeded entry must be gone from the in-memory mapping.
        assert vs_id not in adapter.openai_vector_stores
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue