mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Merge 2cb927b498
into 40fdce79b3
This commit is contained in:
commit
1ea1f66dee
3 changed files with 247 additions and 6 deletions
|
@ -96,11 +96,20 @@ vector_io:
|
||||||
#### Key Parameters for TLS Configuration
|
#### Key Parameters for TLS Configuration
|
||||||
|
|
||||||
- **`secure`**: Enables TLS encryption when set to `true`. Defaults to `false`.
|
- **`secure`**: Enables TLS encryption when set to `true`. Defaults to `false`.
|
||||||
- **`server_pem_path`**: Path to the **server certificate** for verifying the server’s identity (used in one-way TLS).
|
- **`server_pem_path`**: Path to the **server certificate** for verifying the server's identity (used in one-way TLS).
|
||||||
- **`ca_pem_path`**: Path to the **Certificate Authority (CA) certificate** for validating the server certificate (required in mTLS).
|
- **`ca_pem_path`**: Path to the **Certificate Authority (CA) certificate** for validating the server certificate (required in mTLS).
|
||||||
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
|
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
|
||||||
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
|
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
|
||||||
|
|
||||||
|
## Supported Search Modes
|
||||||
|
|
||||||
|
The Milvus provider supports both vector-based and keyword-based (full-text) search modes, but with some limitations:
|
||||||
|
|
||||||
|
- Remote Milvus supports both vector-based and keyword-based search modes.
|
||||||
|
- Inline Milvus (Milvus-Lite) supports only vector-based search. Keyword search is unavailable because Milvus-Lite does not yet implement it. For updates on this feature, see [Milvus GitHub Issue #40848](https://github.com/milvus-io/milvus/issues/40848).
|
||||||
|
|
||||||
|
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in `RAGQueryConfig`. For more details on Milvus's implementation of keyword search modes, refer to the [Milvus documentation](https://milvus.io/docs/full_text_search_with_milvus.md).
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
|
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pymilvus import MilvusClient
|
from pymilvus import DataType, MilvusClient
|
||||||
|
|
||||||
from llama_stack.apis.inference import InterleavedContent
|
from llama_stack.apis.inference import InterleavedContent
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
@ -48,6 +48,8 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.collection_name = collection_name.replace("-", "_")
|
self.collection_name = collection_name.replace("-", "_")
|
||||||
self.consistency_level = consistency_level
|
self.consistency_level = consistency_level
|
||||||
|
self.bm25 = None
|
||||||
|
self.vectorizer = None
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||||
|
@ -58,11 +60,42 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||||
)
|
)
|
||||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||||
|
# Create schema for vector search
|
||||||
|
schema = self.client.create_schema()
|
||||||
|
schema.add_field(
|
||||||
|
field_name="chunk_id",
|
||||||
|
datatype=DataType.VARCHAR,
|
||||||
|
is_primary=True,
|
||||||
|
max_length=100,
|
||||||
|
)
|
||||||
|
schema.add_field(
|
||||||
|
field_name="content",
|
||||||
|
datatype=DataType.VARCHAR,
|
||||||
|
max_length=65535,
|
||||||
|
)
|
||||||
|
schema.add_field(
|
||||||
|
field_name="vector",
|
||||||
|
datatype=DataType.FLOAT_VECTOR,
|
||||||
|
dim=len(embeddings[0]),
|
||||||
|
)
|
||||||
|
schema.add_field(
|
||||||
|
field_name="chunk_content",
|
||||||
|
datatype=DataType.JSON,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create indexes
|
||||||
|
index_params = self.client.prepare_index_params()
|
||||||
|
index_params.add_index(
|
||||||
|
field_name="vector",
|
||||||
|
index_type="FLAT",
|
||||||
|
metric_type="COSINE",
|
||||||
|
)
|
||||||
|
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self.client.create_collection,
|
self.client.create_collection,
|
||||||
self.collection_name,
|
self.collection_name,
|
||||||
dimension=len(embeddings[0]),
|
schema=schema,
|
||||||
auto_id=True,
|
index_params=index_params,
|
||||||
consistency_level=self.consistency_level,
|
consistency_level=self.consistency_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,6 +106,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
data.append(
|
data.append(
|
||||||
{
|
{
|
||||||
"chunk_id": chunk_id,
|
"chunk_id": chunk_id,
|
||||||
|
"content": chunk.content,
|
||||||
"vector": embedding,
|
"vector": embedding,
|
||||||
"chunk_content": chunk.model_dump(),
|
"chunk_content": chunk.model_dump(),
|
||||||
}
|
}
|
||||||
|
@ -92,9 +126,10 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
self.client.search,
|
self.client.search,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
data=[embedding],
|
data=[embedding],
|
||||||
|
anns_field="vector",
|
||||||
limit=k,
|
limit=k,
|
||||||
output_fields=["*"],
|
output_fields=["*"],
|
||||||
search_params={"params": {"radius": score_threshold}},
|
search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
|
||||||
)
|
)
|
||||||
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
|
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
|
||||||
scores = [res["distance"] for res in search_res[0]]
|
scores = [res["distance"] for res in search_res[0]]
|
||||||
|
@ -106,7 +141,17 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
# Simple text search using content field
|
||||||
|
search_res = await asyncio.to_thread(
|
||||||
|
self.client.query,
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
filter=f'content like "%{query_string}%"',
|
||||||
|
output_fields=["*"],
|
||||||
|
limit=k,
|
||||||
|
)
|
||||||
|
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
|
||||||
|
scores = [1.0] * len(chunks) # Simple binary score for text search
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def query_hybrid(
|
async def query_hybrid(
|
||||||
self,
|
self,
|
||||||
|
@ -200,6 +245,14 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
if not index:
|
if not index:
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
|
|
||||||
|
if params and params.get("mode") == "keyword":
|
||||||
|
# Check if this is inline Milvus (Milvus-Lite)
|
||||||
|
if hasattr(self.config, "db_path"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Keyword search is not supported in Milvus-Lite. "
|
||||||
|
"Please use a remote Milvus server for keyword search functionality."
|
||||||
|
)
|
||||||
|
|
||||||
return await index.query_chunks(query, params)
|
return await index.query_chunks(query, params)
|
||||||
|
|
||||||
async def openai_create_vector_store(
|
async def openai_create_vector_store(
|
||||||
|
|
179
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
179
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||||
|
|
||||||
|
# Mock the entire pymilvus module
|
||||||
|
pymilvus_mock = MagicMock()
|
||||||
|
pymilvus_mock.DataType = MagicMock()
|
||||||
|
pymilvus_mock.MilvusClient = MagicMock
|
||||||
|
|
||||||
|
# Apply the mock before importing MilvusIndex
|
||||||
|
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||||
|
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||||
|
|
||||||
|
# This file contains unit tests for the MilvusIndex class. It should only contain
|
||||||
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||||
|
# tests/integration/vector_io/
|
||||||
|
#
|
||||||
|
# How to run this test:
|
||||||
|
#
|
||||||
|
# pytest tests/unit/providers/vector_io/remote/test_milvus.py \
|
||||||
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||||
|
|
||||||
|
MILVUS_PROVIDER = "milvus"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_milvus_client():
|
||||||
|
"""Create a mock Milvus client with common method behaviors."""
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
# Mock collection operations
|
||||||
|
client.has_collection.return_value = False # Initially no collection
|
||||||
|
client.create_collection.return_value = None
|
||||||
|
client.drop_collection.return_value = None
|
||||||
|
|
||||||
|
# Mock insert operation
|
||||||
|
client.insert.return_value = {"insert_count": 10}
|
||||||
|
|
||||||
|
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||||
|
client.search.return_value = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"distance": 0.1,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"distance": 0.2,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock query operation for keyword search (data should be dict, not JSON string)
|
||||||
|
client.query.return_value = [
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk1",
|
||||||
|
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||||
|
"score": 0.9,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk2",
|
||||||
|
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||||
|
"score": 0.8,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk3",
|
||||||
|
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||||
|
"score": 0.7,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def milvus_index(mock_milvus_client):
|
||||||
|
"""Create a MilvusIndex with mocked client."""
|
||||||
|
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||||
|
yield index
|
||||||
|
# No real cleanup needed since we're using mocks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
# Setup: collection doesn't exist initially, then exists after creation
|
||||||
|
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||||
|
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Verify collection was created and data was inserted
|
||||||
|
mock_milvus_client.create_collection.assert_called_once()
|
||||||
|
mock_milvus_client.insert.assert_called_once()
|
||||||
|
|
||||||
|
# Verify the insert call had the right number of chunks
|
||||||
|
insert_call = mock_milvus_client.insert.call_args
|
||||||
|
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_vector(
|
||||||
|
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||||
|
):
|
||||||
|
# Setup: Add chunks first
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Test vector search
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
mock_milvus_client.search.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Test keyword search
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 3
|
||||||
|
mock_milvus_client.query.assert_called_once()
|
||||||
|
|
||||||
|
# Test no results case
|
||||||
|
mock_milvus_client.query.return_value = []
|
||||||
|
response_no_results = await milvus_index.query_keyword(query_string="nonexistent", k=1, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response_no_results, QueryChunksResponse)
|
||||||
|
assert len(response_no_results.chunks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_keyword_search_k_greater_than_results(
|
||||||
|
milvus_index, sample_chunks, sample_embeddings, mock_milvus_client
|
||||||
|
):
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Mock returning only 1 result even though k=5
|
||||||
|
mock_milvus_client.query.return_value = [
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk1",
|
||||||
|
"chunk_content": {"content": "Sentence 1 from document 0", "metadata": {"document_id": "doc1"}},
|
||||||
|
"score": 0.9,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
query_str = "Sentence 1 from document 0"
|
||||||
|
response = await milvus_index.query_keyword(query_string=query_str, k=5, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert 0 < len(response.chunks) <= 4
|
||||||
|
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||||
|
# Test collection deletion
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
|
||||||
|
await milvus_index.delete()
|
||||||
|
|
||||||
|
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
Loading…
Add table
Add a link
Reference in a new issue