Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-10 13:28:40 +00:00
chore: Removing Weaviate, PGVector, and Milvus from unit tests (#3742)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 1s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 4s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 3s
Test External API and Providers / test-external (venv) (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
UI Tests / ui-tests (22) (push) Successful in 48s
Pre-commit / pre-commit (push) Successful in 1m27s
# What does this PR do?

Removes the Weaviate, PGVector, and Milvus unit tests.

## Test Plan

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent 79bed44b04 · commit b96640eca3

4 changed files with 3 additions and 579 deletions
File 1 of 4 — the shared vector IO unit-test fixtures (by content, the pytest conftest for these tests):

```diff
@@ -10,31 +10,26 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import numpy as np
 import pytest
 from chromadb import PersistentClient
-from pymilvus import MilvusClient, connections
 
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
 from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
-from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
+from llama_stack.providers.inline.vector_io.milvus.config import SqliteKVStoreConfig
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
 from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
-from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
 from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
 from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
-from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
-from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
 
 EMBEDDING_DIMENSION = 384
 COLLECTION_PREFIX = "test_collection"
-MILVUS_ALIAS = "test_milvus"
 
 
-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
+@pytest.fixture(params=["sqlite_vec", "faiss", "chroma", "pgvector"])
 def vector_provider(request):
     return request.param
 
@@ -170,46 +165,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
     await adapter.shutdown()
 
 
-@pytest.fixture(scope="session")
-def milvus_vec_db_path(tmp_path_factory):
-    db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
-    return db_path
-
-
-@pytest.fixture
-async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
-    client = MilvusClient(milvus_vec_db_path)
-    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
-    connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
-    index = MilvusIndex(client, name, consistency_level="Strong")
-    index.db_path = milvus_vec_db_path
-    yield index
-
-
-@pytest.fixture
-async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
-    config = MilvusVectorIOConfig(
-        db_path=milvus_vec_db_path,
-        kvstore=unique_kvstore_config,
-    )
-    adapter = MilvusVectorIOAdapter(
-        config=config,
-        inference_api=mock_inference_api,
-        files_api=None,
-    )
-    await adapter.initialize()
-    await adapter.register_vector_db(
-        VectorDB(
-            identifier=adapter.metadata_collection_name,
-            provider_id="test_provider",
-            embedding_model="test_model",
-            embedding_dimension=128,
-        )
-    )
-    yield adapter
-    await adapter.shutdown()
-
-
 @pytest.fixture
 def faiss_vec_db_path(tmp_path_factory):
     db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
@@ -450,81 +405,14 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
     await adapter.shutdown()
 
 
-@pytest.fixture(scope="session")
-def weaviate_vec_db_path(tmp_path_factory):
-    db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
-    return db_path
-
-
-@pytest.fixture
-async def weaviate_vec_index(weaviate_vec_db_path):
-    import pytest_socket
-    import weaviate
-
-    pytest_socket.enable_socket()
-    client = weaviate.connect_to_embedded(
-        hostname="localhost",
-        port=8080,
-        grpc_port=50051,
-        persistence_data_path=weaviate_vec_db_path,
-    )
-    index = WeaviateIndex(client=client, collection_name="Testcollection")
-    await index.initialize()
-    yield index
-    await index.delete()
-    client.close()
-
-
-@pytest.fixture
-async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
-    import pytest_socket
-    import weaviate
-
-    pytest_socket.enable_socket()
-
-    client = weaviate.connect_to_embedded(
-        hostname="localhost",
-        port=8080,
-        grpc_port=50051,
-        persistence_data_path=weaviate_vec_db_path,
-    )
-
-    config = WeaviateVectorIOConfig(
-        weaviate_cluster_url="localhost:8080",
-        weaviate_api_key=None,
-        kvstore=unique_kvstore_config,
-    )
-    adapter = WeaviateVectorIOAdapter(
-        config=config,
-        inference_api=mock_inference_api,
-        files_api=None,
-    )
-    collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
-    await adapter.initialize()
-    await adapter.register_vector_db(
-        VectorDB(
-            identifier=collection_id,
-            provider_id="test_provider",
-            embedding_model="test_model",
-            embedding_dimension=embedding_dimension,
-        )
-    )
-    adapter.test_collection_id = collection_id
-    yield adapter
-    await adapter.shutdown()
-    client.close()
-
-
 @pytest.fixture
 def vector_io_adapter(vector_provider, request):
     vector_provider_dict = {
-        "milvus": "milvus_vec_adapter",
         "faiss": "faiss_vec_adapter",
         "sqlite_vec": "sqlite_vec_adapter",
         "chroma": "chroma_vec_adapter",
        "qdrant": "qdrant_vec_adapter",
         "pgvector": "pgvector_vec_adapter",
-        "weaviate": "weaviate_vec_adapter",
     }
     return request.getfixturevalue(vector_provider_dict[vector_provider])
```
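The fixture indirection above is worth calling out: `vector_io_adapter` resolves the active provider's adapter fixture by name at runtime with `request.getfixturevalue`, which is why dropping a provider only touches the `params` list and one dict entry. A minimal self-contained sketch of the same pattern (the fixture names below are illustrative, not from the repo):

```python
import pytest


@pytest.fixture(params=["fast", "slow"])
def backend_name(request):
    # Each param value runs the dependent tests once against that backend.
    return request.param


@pytest.fixture
def fast_backend():
    return {"kind": "fast"}


@pytest.fixture
def slow_backend():
    return {"kind": "slow"}


@pytest.fixture
def backend(backend_name, request):
    # Resolve the concrete fixture lazily by name, as the conftest above
    # does for its vector IO adapters.
    return request.getfixturevalue(f"{backend_name}_backend")


def test_backend_has_kind(backend):
    assert "kind" in backend
```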
File 2 of 4 — the Milvus unit test (identified by its own run instructions as tests/unit/providers/vector_io/test_milvus.py), deleted in full:

```diff
@@ -1,326 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from unittest.mock import MagicMock, patch
-
-import numpy as np
-import pytest
-
-from llama_stack.apis.vector_io import QueryChunksResponse
-
-# Mock the entire pymilvus module
-pymilvus_mock = MagicMock()
-pymilvus_mock.DataType = MagicMock()
-pymilvus_mock.MilvusClient = MagicMock
-pymilvus_mock.RRFRanker = MagicMock
-pymilvus_mock.WeightedRanker = MagicMock
-pymilvus_mock.AnnSearchRequest = MagicMock
-
-# Apply the mock before importing MilvusIndex
-with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
-    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
-
-# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
-# tests which are specific to this class. More general (API-level) tests should be placed in
-# tests/integration/vector_io/
-#
-# How to run this test:
-#
-# pytest tests/unit/providers/vector_io/test_milvus.py \
-#     -v -s --tb=short --disable-warnings --asyncio-mode=auto
-
-MILVUS_PROVIDER = "milvus"
-
-
-@pytest.fixture
-async def mock_milvus_client() -> MagicMock:
-    """Create a mock Milvus client with common method behaviors."""
-    client = MagicMock()
-
-    # Mock collection operations
-    client.has_collection.return_value = False  # Initially no collection
-    client.create_collection.return_value = None
-    client.drop_collection.return_value = None
-
-    # Mock insert operation
-    client.insert.return_value = {"insert_count": 10}
-
-    # Mock search operation - return mock results (data should be dict, not JSON string)
-    client.search.return_value = [
-        [
-            {
-                "id": 0,
-                "distance": 0.1,
-                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
-            },
-            {
-                "id": 1,
-                "distance": 0.2,
-                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
-            },
-        ]
-    ]
-
-    # Mock query operation for keyword search (data should be dict, not JSON string)
-    client.query.return_value = [
-        {
-            "chunk_id": "chunk1",
-            "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
-            "score": 0.9,
-        },
-        {
-            "chunk_id": "chunk2",
-            "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
-            "score": 0.8,
-        },
-        {
-            "chunk_id": "chunk3",
-            "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
-            "score": 0.7,
-        },
-    ]
-
-    return client
-
-
-@pytest.fixture
-async def milvus_index(mock_milvus_client):
-    """Create a MilvusIndex with mocked client."""
-    index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
-    yield index
-    # No real cleanup needed since we're using mocks
-
-
-async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
-    # Setup: collection doesn't exist initially, then exists after creation
-    mock_milvus_client.has_collection.side_effect = [False, True]
-
-    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
-
-    # Verify collection was created and data was inserted
-    mock_milvus_client.create_collection.assert_called_once()
-    mock_milvus_client.insert.assert_called_once()
-
-    # Verify the insert call had the right number of chunks
-    insert_call = mock_milvus_client.insert.call_args
-    assert len(insert_call[1]["data"]) == len(sample_chunks)
-
-
-async def test_query_chunks_vector(
-    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
-):
-    # Setup: Add chunks first
-    mock_milvus_client.has_collection.return_value = True
-    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
-
-    # Test vector search
-    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
-    response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
-
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) == 2
-    mock_milvus_client.search.assert_called_once()
-
-
-async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
-    mock_milvus_client.has_collection.return_value = True
-    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
-
-    # Test keyword search
-    query_string = "Sentence 5"
-    response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
-
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) == 2
-
-
-async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
-    """Test that when BM25 search fails, the system falls back to simple text search."""
-    mock_milvus_client.has_collection.return_value = True
-    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
-
-    # Force BM25 search to fail
-    mock_milvus_client.search.side_effect = Exception("BM25 search not available")
-
-    # Mock simple text search results
-    mock_milvus_client.query.return_value = [
-        {
-            "chunk_id": "chunk1",
-            "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
-        },
-        {
-            "chunk_id": "chunk2",
-            "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
-        },
-    ]
-
-    # Test keyword search that should fall back to simple text search
-    query_string = "Python"
-    response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
-
-    # Verify response structure
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) > 0, "Fallback search should return results"
-
-    # Verify that simple text search was used (query method called instead of search)
-    mock_milvus_client.query.assert_called_once()
-    mock_milvus_client.search.assert_called_once()  # Called once but failed
-
-    # Verify the query uses parameterized filter with filter_params
-    query_call_args = mock_milvus_client.query.call_args
-    assert "filter" in query_call_args[1], "Query should include filter for text search"
-    assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
-    assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
-
-    # Verify all returned chunks have score 1.0 (simple binary scoring)
-    assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
-
-
-async def test_delete_collection(milvus_index, mock_milvus_client):
-    # Test collection deletion
-    mock_milvus_client.has_collection.return_value = True
-
-    await milvus_index.delete()
-
-    mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
-
-
-async def test_query_hybrid_search_rrf(
-    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
-):
-    """Test hybrid search with RRF reranker."""
-    mock_milvus_client.has_collection.return_value = True
-    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
-
-    # Mock hybrid search results
-    mock_milvus_client.hybrid_search.return_value = [
-        [
-            {
-                "id": 0,
-                "distance": 0.1,
-                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
-            },
-            {
-                "id": 1,
-                "distance": 0.2,
-                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
-            },
-        ]
-    ]
-
-    # Test hybrid search with RRF reranker
-    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
-    query_string = "test query"
-    response = await milvus_index.query_hybrid(
-        embedding=query_embedding,
-        query_string=query_string,
-        k=2,
-        score_threshold=0.0,
-        reranker_type="rrf",
-        reranker_params={"impact_factor": 60.0},
-    )
-
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) == 2
-    assert len(response.scores) == 2
-
-    # Verify hybrid search was called with correct parameters
-    mock_milvus_client.hybrid_search.assert_called_once()
-    call_args = mock_milvus_client.hybrid_search.call_args
-
-    # Check that the request contains both vector and BM25 search requests
-    reqs = call_args[1]["reqs"]
-    assert len(reqs) == 2
-    assert reqs[0].anns_field == "vector"
-    assert reqs[1].anns_field == "sparse"
-    ranker = call_args[1]["ranker"]
-    assert ranker is not None
-
-
-async def test_query_hybrid_search_weighted(
-    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
-):
-    """Test hybrid search with weighted reranker."""
-    mock_milvus_client.has_collection.return_value = True
-    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
-
-    # Mock hybrid search results
-    mock_milvus_client.hybrid_search.return_value = [
-        [
-            {
-                "id": 0,
-                "distance": 0.1,
-                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
-            },
-            {
-                "id": 1,
-                "distance": 0.2,
-                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
-            },
-        ]
-    ]
-
-    # Test hybrid search with weighted reranker
-    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
-    query_string = "test query"
-    response = await milvus_index.query_hybrid(
-        embedding=query_embedding,
-        query_string=query_string,
-        k=2,
-        score_threshold=0.0,
-        reranker_type="weighted",
-        reranker_params={"alpha": 0.7},
-    )
-
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) == 2
-    assert len(response.scores) == 2
-
-    # Verify hybrid search was called with correct parameters
-    mock_milvus_client.hybrid_search.assert_called_once()
-    call_args = mock_milvus_client.hybrid_search.call_args
-    ranker = call_args[1]["ranker"]
-    assert ranker is not None
-
-
-async def test_query_hybrid_search_default_rrf(
-    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
-):
-    """Test hybrid search with default RRF reranker (no reranker_type specified)."""
-    mock_milvus_client.has_collection.return_value = True
-    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
-
-    # Mock hybrid search results
-    mock_milvus_client.hybrid_search.return_value = [
-        [
-            {
-                "id": 0,
-                "distance": 0.1,
-                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
-            },
-        ]
-    ]
-
-    # Test hybrid search with default reranker (should be RRF)
-    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
-    query_string = "test query"
-    response = await milvus_index.query_hybrid(
-        embedding=query_embedding,
-        query_string=query_string,
-        k=1,
-        score_threshold=0.0,
-        reranker_type="unknown_type",  # Should default to RRF
-        reranker_params=None,  # Should use default impact_factor
-    )
-
-    assert isinstance(response, QueryChunksResponse)
-    assert len(response.chunks) == 1
-
-    # Verify hybrid search was called with RRF reranker
-    mock_milvus_client.hybrid_search.assert_called_once()
-    call_args = mock_milvus_client.hybrid_search.call_args
-    ranker = call_args[1]["ranker"]
-    assert ranker is not None
```
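The deleted file's most reusable trick is how it avoided a hard dependency on pymilvus: it planted a `MagicMock` in `sys.modules` before importing `MilvusIndex`, so the provider module's own `import pymilvus` resolved to the mock. A generic sketch of the technique, with a placeholder package name rather than a real one:

```python
# Sketch of the sys.modules patching trick the deleted test used: install a
# MagicMock under the dependency's module name before anything imports it,
# so the heavy package never needs to be installed.
# "some_heavy_dep" is a placeholder, not a real package.
from unittest.mock import MagicMock, patch

heavy_dep_mock = MagicMock()
heavy_dep_mock.Client = MagicMock  # class names the code under test references

with patch.dict("sys.modules", {"some_heavy_dep": heavy_dep_mock}):
    # Any `import some_heavy_dep` executed here resolves to the mock,
    # including imports performed by modules imported inside this block.
    import some_heavy_dep

    client = some_heavy_dep.Client()  # a MagicMock instance
```

One caveat: modules imported while the patch is active keep their references to the mock after the `with` block exits, which is exactly what the deleted test relied on when it imported `MilvusIndex` inside the block.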
File 3 of 4 — the PGVector unit test, deleted in full:

```diff
@@ -1,138 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import asyncio
-from unittest.mock import patch
-
-import pytest
-
-from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex
-
-PGVECTOR_PROVIDER = "pgvector"
-
-
-@pytest.fixture(scope="session")
-def loop():
-    return asyncio.new_event_loop()
-
-
-@pytest.fixture
-def embedding_dimension():
-    """Default embedding dimension for tests."""
-    return 384
-
-
-@pytest.fixture
-async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
-    """Create a PGVectorIndex instance with mocked database connection."""
-    connection, cursor = mock_psycopg2_connection
-
-    vector_db = VectorDB(
-        identifier="test-vector-db",
-        embedding_model="test-model",
-        embedding_dimension=embedding_dimension,
-        provider_id=PGVECTOR_PROVIDER,
-        provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
-    )
-
-    with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
-        # Use explicit COSINE distance metric for consistent testing
-        index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
-
-    return index, cursor
-
-
-class TestPGVectorIndex:
-    def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
-        connection, cursor = mock_psycopg2_connection
-
-        vector_db = VectorDB(
-            identifier="test-vector-db",
-            embedding_model="test-model",
-            embedding_dimension=embedding_dimension,
-            provider_id=PGVECTOR_PROVIDER,
-            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
-        )
-
-        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
-            index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
-            assert index.distance_metric == "L2"
-            with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
-                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
-
-    def test_get_pgvector_search_function(self, pgvector_index):
-        index, cursor = pgvector_index
-        supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
-
-        for metric, function in supported_metrics.items():
-            index.distance_metric = metric
-            assert index.get_pgvector_search_function() == function
-
-    def test_check_distance_metric_availability(self, pgvector_index):
-        index, cursor = pgvector_index
-        supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
-
-        for metric in supported_metrics:
-            index.check_distance_metric_availability(metric)
-
-        with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
-            index.check_distance_metric_availability("INVALID")
-
-    def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
-        connection, cursor = mock_psycopg2_connection
-
-        vector_db = VectorDB(
-            identifier="test-vector-db",
-            embedding_model="test-model",
-            embedding_dimension=embedding_dimension,
-            provider_id=PGVECTOR_PROVIDER,
-            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
-        )
-
-        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
-            with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
-                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
-
-            with pytest.raises(ValueError, match="Supported metrics are:"):
-                PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
-
-            try:
-                index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
-                assert index.distance_metric == "COSINE"
-            except ValueError:
-                pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")
-
-    def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
-        connection, cursor = mock_psycopg2_connection
-
-        vector_db = VectorDB(
-            identifier="test-vector-db",
-            embedding_model="test-model",
-            embedding_dimension=embedding_dimension,
-            provider_id=PGVECTOR_PROVIDER,
-            provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
-        )
-
-        supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]
-
-        with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
-            for metric in supported_metrics:
-                try:
-                    index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
-                    assert index.distance_metric == metric
-
-                    expected_operators = {
-                        "L2": "<->",
-                        "L1": "<+>",
-                        "COSINE": "<=>",
-                        "INNER_PRODUCT": "<#>",
-                        "HAMMING": "<~>",
-                        "JACCARD": "<%>",
-                    }
-                    assert index.get_pgvector_search_function() == expected_operators[metric]
-                except Exception as e:
-                    pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")
```
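For reference, the operator table the deleted PGVector test asserted against maps each distance metric name to a pgvector SQL operator. A standalone sketch of the same lookup and validation (the helper name here is illustrative; in the real code this behavior lives on `PGVectorIndex`):

```python
# The operator strings come from the deleted test's own expected_operators
# table; get_search_operator is a hypothetical standalone helper.
PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION = {
    "L2": "<->",
    "L1": "<+>",
    "COSINE": "<=>",
    "INNER_PRODUCT": "<#>",
    "HAMMING": "<~>",
    "JACCARD": "<%>",
}


def get_search_operator(distance_metric: str) -> str:
    """Return the pgvector operator for a metric, mirroring the validation
    the deleted test asserted on PGVectorIndex."""
    if distance_metric not in PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION:
        supported = list(PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION)
        raise ValueError(
            f"Distance metric '{distance_metric}' is not supported. "
            f"Supported metrics are: {supported}"
        )
    return PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[distance_metric]


assert get_search_operator("COSINE") == "<=>"
```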
File 4 of 4 — a unit test for the inline VectorIO providers; the only change swaps where `VECTOR_DBS_PREFIX` is imported from:

```diff
@@ -19,7 +19,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreChunkingStrategyAuto,
     VectorStoreFileObject,
 )
-from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
+from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
 
 # This test is a unit test for the inline VectorIO providers. This should only contain
 # tests which are specific to this class. More general (API-level) tests should be placed in
```
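The likely motivation, hedged since the PR text doesn't state it: the Milvus provider module imports pymilvus at module load (the deleted Milvus test had to mock pymilvus before importing from it), so importing `VECTOR_DBS_PREFIX` from there drags pymilvus into the unit-test environment, while the sqlite_vec module exposes the same constant without that dependency. A trivial illustration (the key construction below is hypothetical usage, not from the diff):

```python
# After this PR: get the shared prefix from a module unit tests already use,
# rather than from the Milvus module with its pymilvus import.
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX

# Hypothetical usage: building a kvstore key for a registered vector DB.
key = f"{VECTOR_DBS_PREFIX}my-test-db"
```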