From 980c7c244d7b97c204c48ed810c76eafead6af81 Mon Sep 17 00:00:00 2001 From: ChristianZaccaria Date: Wed, 3 Sep 2025 10:34:48 +0100 Subject: [PATCH] Remove Weaviate unit tests --- .../remote/vector_io/weaviate/weaviate.py | 54 ++-- tests/integration/vector_io/test_vector_io.py | 1 - tests/unit/providers/vector_io/conftest.py | 2 +- .../vector_io/remote/test_weaviate.py | 269 ------------------ 4 files changed, 34 insertions(+), 292 deletions(-) delete mode 100644 tests/unit/providers/vector_io/remote/test_weaviate.py diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index f17646d77..ea56705ef 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -104,11 +104,15 @@ class WeaviateIndex(EmbeddingIndex): sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) collection = self.client.collections.get(sanitized_collection_name) - results = collection.query.near_vector( - near_vector=embedding.tolist(), - limit=k, - return_metadata=wvc.query.MetadataQuery(distance=True), - ) + try: + results = collection.query.near_vector( + near_vector=embedding.tolist(), + limit=k, + return_metadata=wvc.query.MetadataQuery(distance=True), + ) + except Exception as e: + log.error(f"Weaviate client vector search failed: {e}") + raise chunks = [] scores = [] @@ -123,8 +127,8 @@ class WeaviateIndex(EmbeddingIndex): if doc.metadata.distance is None: continue - # Convert cosine distance ∈ [0,2] → cosine similarity ∈ [-1,1] - score = 1.0 - float(doc.metadata.distance) + # Convert cosine distance ∈ [0,2] -> normalized cosine similarity ∈ [0,1] + score = 1.0 - (float(doc.metadata.distance) / 2.0) if score < score_threshold: continue @@ -167,11 +171,15 @@ class WeaviateIndex(EmbeddingIndex): collection = self.client.collections.get(sanitized_collection_name) # Perform BM25 keyword search on chunk_content field - results = collection.query.bm25( - query=query_string, - limit=k, - return_metadata=wvc.query.MetadataQuery(score=True), - ) + try: + results = collection.query.bm25( + query=query_string, + limit=k, + return_metadata=wvc.query.MetadataQuery(score=True), + ) + except Exception as e: + log.error(f"Weaviate client keyword search failed: {e}") + raise chunks = [] scores = [] @@ -229,14 +237,18 @@ class WeaviateIndex(EmbeddingIndex): rerank = HybridFusion.RELATIVE_SCORE # Perform hybrid search using Weaviate's native hybrid search - results = collection.query.hybrid( - query=query_string, - alpha=0.5, # Range <0, 1>, where 0.5 will equally favor vector and keyword search - vector=embedding.tolist(), - limit=k, - fusion_type=rerank, - return_metadata=wvc.query.MetadataQuery(score=True), - ) + try: + results = collection.query.hybrid( + query=query_string, + alpha=0.5, # Range <0, 1>, where 0.5 will equally favor vector and keyword search + vector=embedding.tolist(), + limit=k, + fusion_type=rerank, + return_metadata=wvc.query.MetadataQuery(score=True), + ) + except Exception as e: + log.error(f"Weaviate client hybrid search failed: {e}") + raise chunks = [] scores = [] @@ -283,7 +295,7 @@ class WeaviateVectorIOAdapter( self.openai_vector_stores: dict[str, dict[str, Any]] = {} self.metadata_collection_name = "openai_vector_stores_metadata" - def _get_client(self) -> weaviate.Client: + def _get_client(self) -> weaviate.WeaviateClient: if "localhost" in self.config.weaviate_cluster_url: log.info("using Weaviate locally in container") host, port = self.config.weaviate_cluster_url.split(":") diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index c6c7f3c6b..7bfe31dd6 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -140,7 +140,6 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e "inline::milvus": {"score_threshold": -1.0}, "inline::qdrant": {"score_threshold": -1.0}, "remote::qdrant": {"score_threshold": -1.0}, - "remote::weaviate": {"score_threshold": -1.0}, } vector_db_name = "test_precomputed_embeddings_db" register_response = client_with_empty_registry.vector_dbs.register( diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 876e0401d..70ace695e 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -457,7 +457,7 @@ def weaviate_vec_db_path(tmp_path_factory): @pytest.fixture -async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension): +async def weaviate_vec_index(weaviate_vec_db_path): import pytest_socket import weaviate diff --git a/tests/unit/providers/vector_io/remote/test_weaviate.py b/tests/unit/providers/vector_io/remote/test_weaviate.py deleted file mode 100644 index 534b3b6b1..000000000 --- a/tests/unit/providers/vector_io/remote/test_weaviate.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from unittest.mock import ANY, MagicMock, patch - -import numpy as np -import pytest - -from llama_stack.apis.vector_io import QueryChunksResponse - -# Mock the Weaviate client -weaviate_mock = MagicMock() - -# Apply the mock before importing WeaviateIndex -with patch.dict("sys.modules", {"weaviate": weaviate_mock}): - from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex - -# This test is a unit test for the WeaviateIndex class. This should only contain -# tests which are specific to this class. More general (API-level) tests should be placed in -# tests/integration/vector_io/ -# -# How to run this test: -# -# pytest tests/unit/providers/vector_io/remote/test_weaviate.py \ -# -v -s --tb=short --disable-warnings --asyncio-mode=auto - -WEAVIATE_PROVIDER = "weaviate" - - -@pytest.fixture -async def mock_weaviate_client() -> MagicMock: - """Create a mock Weaviate client with common method behaviors.""" - mock_client = MagicMock() - mock_collection = MagicMock() - - # Mock collection data operations - mock_collection.data.insert_many.return_value = None - mock_collection.data.delete_many.return_value = None - - # Mock collection search operations - mock_collection.query.near_vector.return_value = None - mock_collection.query.bm25.return_value = None - mock_collection.query.hybrid.return_value = None - mock_results = MagicMock() - mock_results.objects = [MagicMock(), MagicMock()] - mock_collection.query.near_vector.return_value = mock_results - - # Mock client collection operations - mock_client.collections.get.return_value = mock_collection - mock_client.collections.exists.return_value = True - mock_client.collections.delete.return_value = None - - # Mock client close operation - mock_client.close.return_value = None - - return mock_client - - -@pytest.fixture -async def weaviate_index(mock_weaviate_client): - """Create a WeaviateIndex with mocked client.""" - index = WeaviateIndex(client=mock_weaviate_client, collection_name="Testcollection") - yield index - # No real cleanup needed since we're using mocks - - -async def test_add_chunks(weaviate_index, sample_chunks, sample_embeddings, mock_weaviate_client): - # Setup: Add chunks first - await weaviate_index.add_chunks(sample_chunks, sample_embeddings) - - # Assert - mock_weaviate_client.collections.get.assert_called_once_with("Testcollection") - mock_weaviate_client.collections.get.return_value.data.insert_many.assert_called_once() - - # Verify the insert call had the right number of chunks - data_objects, _ = mock_weaviate_client.collections.get.return_value.data.insert_many.call_args - assert len(data_objects[0]) == len(sample_chunks) - - -async def test_query_chunks_vector( - weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client -): - # Setup: Add chunks first - await weaviate_index.add_chunks(sample_chunks, sample_embeddings) - - # Create mock objects that match Weaviate's response structure - mock_objects = [] - for i, chunk in enumerate(sample_chunks[:2]): # Return first 2 chunks - mock_obj = MagicMock() - mock_obj.properties = {"chunk_content": chunk.model_dump_json()} - mock_obj.metadata.distance = 0.1 + i * 0.1 # Mock distances - mock_objects.append(mock_obj) - - mock_results = MagicMock() - mock_results.objects = mock_objects - mock_weaviate_client.collections.get.return_value.query.near_vector.return_value = mock_results - - # Test vector search - query_embedding = np.random.rand(embedding_dimension).astype(np.float32) - response = await weaviate_index.query_vector(query_embedding, k=2, score_threshold=0.0) - - assert isinstance(response, QueryChunksResponse) - assert len(response.chunks) == 2 - assert len(response.scores) == 2 - mock_weaviate_client.collections.get.return_value.query.near_vector.assert_called_once_with( - near_vector=query_embedding.tolist(), - limit=2, - return_metadata=ANY, - ) - - -async def test_query_chunks_keyword_search(weaviate_index, sample_chunks, sample_embeddings, mock_weaviate_client): - await weaviate_index.add_chunks(sample_chunks, sample_embeddings) - - # Find chunks that contain "Sentence 5" - matching_chunks = [chunk for chunk in sample_chunks if "Sentence 5" in chunk.content] - - # Create mock objects that match Weaviate's BM25 response structure - # Return the first 2 matching chunks - mock_objects = [] - for i, chunk in enumerate(matching_chunks[:2]): - mock_obj = MagicMock() - mock_obj.properties = {"chunk_content": chunk.model_dump_json()} - mock_obj.metadata.score = 0.9 - i * 0.1 - mock_objects.append(mock_obj) - - mock_results = MagicMock() - mock_results.objects = mock_objects - mock_weaviate_client.collections.get.return_value.query.bm25.return_value = mock_results - - # Test keyword search - query_string = "Sentence 5" - response = await weaviate_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0) - - assert isinstance(response, QueryChunksResponse) - assert len(response.chunks) == 2 - assert len(response.scores) == 2 - # Verify that the returned chunks contain the search term - assert all("Sentence 5" in chunk.content for chunk in response.chunks) - mock_weaviate_client.collections.get.return_value.query.bm25.assert_called_once_with( - query=query_string, - limit=2, - return_metadata=ANY, - ) - - -async def test_delete_collection(weaviate_index, mock_weaviate_client): - # Test collection deletion (when chunk_ids is None, it deletes the entire collection) - mock_weaviate_client.collections.exists.return_value = True - - await weaviate_index.delete() - - mock_weaviate_client.collections.exists.assert_called_once_with("Testcollection") - mock_weaviate_client.collections.delete.assert_called_once_with("Testcollection") - - -async def test_delete_chunks(weaviate_index, mock_weaviate_client): - # Test deleting specific chunks using ChunkForDeletion objects - from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion - - chunks_for_deletion = [ - ChunkForDeletion(chunk_id="chunk-1", document_id="doc-1"), - ChunkForDeletion(chunk_id="chunk-2", document_id="doc-1"), - ChunkForDeletion(chunk_id="chunk-3", document_id="doc-2"), - ] - - await weaviate_index.delete_chunks(chunks_for_deletion) - - mock_weaviate_client.collections.get.assert_called_once_with("Testcollection") - mock_weaviate_client.collections.get.return_value.data.delete_many.assert_called_once() - - -async def test_query_hybrid_rrf( - weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client -): - # Test hybrid search with RRF reranking - from weaviate.classes.query import HybridFusion - - from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF - - await weaviate_index.add_chunks(sample_chunks, sample_embeddings) - - # Find chunks that contain "Sentence 5" - matching_chunks = [chunk for chunk in sample_chunks if "Sentence 5" in chunk.content] - - # Create mock objects for hybrid search response - mock_objects = [] - for i, chunk in enumerate(matching_chunks[:2]): - mock_obj = MagicMock() - mock_obj.properties = {"chunk_content": chunk.model_dump_json()} - mock_obj.metadata.score = 0.85 + i * 0.05 - mock_objects.append(mock_obj) - - mock_results = MagicMock() - mock_results.objects = mock_objects - mock_weaviate_client.collections.get.return_value.query.hybrid.return_value = mock_results - - # Test hybrid search with RRF - query_string = "Sentence 5" - query_embedding = np.random.rand(embedding_dimension).astype(np.float32) - response = await weaviate_index.query_hybrid( - embedding=query_embedding, query_string=query_string, k=2, score_threshold=0.0, reranker_type=RERANKER_TYPE_RRF - ) - - assert isinstance(response, QueryChunksResponse) - assert len(response.chunks) == 2 - assert len(response.scores) == 2 - assert all("Sentence 5" in chunk.content for chunk in response.chunks) - - # Verify the hybrid method was called with correct parameters - mock_weaviate_client.collections.get.return_value.query.hybrid.assert_called_once_with( - query=query_string, - alpha=0.5, - vector=query_embedding.tolist(), - limit=2, - fusion_type=HybridFusion.RANKED, - return_metadata=ANY, - ) - - -async def test_query_hybrid_normalized( - weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client -): - from weaviate.classes.query import HybridFusion - - await weaviate_index.add_chunks(sample_chunks, sample_embeddings) - - # Find chunks that contain "Sentence 3" for different results - matching_chunks = [chunk for chunk in sample_chunks if "Sentence 3" in chunk.content] - - # Create mock objects for hybrid search response - mock_objects = [] - for i, chunk in enumerate(matching_chunks[:2]): - mock_obj = MagicMock() - mock_obj.properties = {"chunk_content": chunk.model_dump_json()} - mock_obj.metadata.score = 0.75 + i * 0.1 # Mock hybrid scores - mock_objects.append(mock_obj) - - mock_results = MagicMock() - mock_results.objects = mock_objects - mock_weaviate_client.collections.get.return_value.query.hybrid.return_value = mock_results - - # Test hybrid search with normalized reranking - query_string = "Sentence 3" - query_embedding = np.random.rand(embedding_dimension).astype(np.float32) - response = await weaviate_index.query_hybrid( - embedding=query_embedding, query_string=query_string, k=2, score_threshold=0.0, reranker_type="normalized" - ) - - assert isinstance(response, QueryChunksResponse) - assert len(response.chunks) == 2 - assert len(response.scores) == 2 - assert all("Sentence 3" in chunk.content for chunk in response.chunks) - - # Verify the hybrid method was called with correct parameters - mock_weaviate_client.collections.get.return_value.query.hybrid.assert_called_once_with( - query=query_string, - alpha=0.5, - vector=query_embedding.tolist(), - limit=2, - fusion_type=HybridFusion.RELATIVE_SCORE, - return_metadata=ANY, - ) - - -# TODO: Write tests for the WeaviateVectorIOAdapter class.