diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index 59b6bf124..06bfdf397 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -10,7 +10,7 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
-from weaviate.classes.query import Filter
+from weaviate.classes.query import Filter, HybridFusion

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
@@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
     OpenAIVectorStoreMixin,
 )
 from llama_stack.providers.utils.memory.vector_store import (
+    RERANKER_TYPE_RRF,
     ChunkForDeletion,
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -88,6 +89,9 @@ class WeaviateIndex(EmbeddingIndex):
         collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))

     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+        log.debug(
+            f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
+        )
         sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
         collection = self.client.collections.get(sanitized_collection_name)

@@ -115,6 +119,7 @@ class WeaviateIndex(EmbeddingIndex):
             chunks.append(chunk)
             scores.append(score)

+        log.debug(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
         return QueryChunksResponse(chunks=chunks, scores=scores)

     async def delete(self, chunk_ids: list[str] | None = None) -> None:
@@ -136,7 +141,46 @@ class WeaviateIndex(EmbeddingIndex):
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Weaviate")
+        """
+        Performs BM25-based keyword search using Weaviate's built-in full-text search.
+        Args:
+            query_string: The text query for keyword search
+            k: Limit of number of results to return
+            score_threshold: Minimum BM25 score a result needs in order to be returned
+        Returns:
+            QueryChunksResponse with the matching chunks and their BM25 scores
+        """
+        log.debug(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+
+        # Perform BM25 keyword search (no query_properties are passed, so the search is not limited to one field)
+        results = collection.query.bm25(
+            query=query_string,
+            limit=k,
+            return_metadata=wvc.query.MetadataQuery(score=True),
+        )
+
+        chunks = []
+        scores = []
+        for doc in results.objects:
+            chunk_json = doc.properties["chunk_content"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                log.exception(f"Failed to parse document: {chunk_json}")
+                continue
+
+            score = doc.metadata.score if doc.metadata.score is not None else 0.0
+            if score < score_threshold:
+                continue
+
+            chunks.append(chunk)
+            scores.append(score)
+
+        log.debug(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
+        return QueryChunksResponse(chunks=chunks, scores=scores)

     async def query_hybrid(
         self,
@@ -147,7 +191,66 @@ class WeaviateIndex(EmbeddingIndex):
         reranker_type: str,
         reranker_params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Hybrid search is not supported in Weaviate")
+        """
+        Hybrid search combining vector similarity and keyword search using Weaviate's native hybrid search.
+        Args:
+            embedding: The query embedding vector
+            query_string: The text query for keyword search
+            k: Limit of number of results to return
+            score_threshold: Minimum fused hybrid score a result needs in order to be returned
+            reranker_type: Type of reranker to use ("rrf" selects ranked fusion, anything else relative score)
+            reranker_params: Parameters for the reranker; an "alpha" entry overrides the default 0.5 weighting
+        Returns:
+            QueryChunksResponse with combined results
+        """
+        log.debug(
+            f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}"
+        )
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+
+        # Ranked (RRF) reranker fusion type
+        if reranker_type == RERANKER_TYPE_RRF:
+            rerank = HybridFusion.RANKED
+        # Relative score (Normalized) reranker fusion type
+        else:
+            rerank = HybridFusion.RELATIVE_SCORE
+
+        # Range <0, 1>, where 0.5 will equally favor vector and keyword search. Honor a caller-supplied
+        # "alpha" (NOTE(review): assumes it follows Weaviate's convention, 1 = vector only - confirm).
+        alpha = reranker_params.get("alpha", 0.5) if reranker_params else 0.5
+
+        # Perform hybrid search using Weaviate's native hybrid search
+        results = collection.query.hybrid(
+            query=query_string,
+            alpha=alpha,
+            vector=embedding.tolist(),
+            limit=k,
+            fusion_type=rerank,
+            return_metadata=wvc.query.MetadataQuery(score=True),
+        )
+
+        chunks = []
+        scores = []
+        for doc in results.objects:
+            chunk_json = doc.properties["chunk_content"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                log.exception(f"Failed to parse document: {chunk_json}")
+                continue
+
+            score = doc.metadata.score if doc.metadata.score is not None else 0.0
+            if score < score_threshold:
+                continue
+
+            log.debug(f"Document {chunk.metadata.get('document_id')} has score {score}")
+            chunks.append(chunk)
+            scores.append(score)
+
+        log.debug(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
+        return QueryChunksResponse(chunks=chunks, scores=scores)


 class WeaviateVectorIOAdapter(
diff --git
a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index aaa470970..857fbe910 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -50,6 +50,7 @@ class ChunkForDeletion(BaseModel): # Constants for reranker types RERANKER_TYPE_RRF = "rrf" RERANKER_TYPE_WEIGHTED = "weighted" +RERANKER_TYPE_NORMALIZED = "normalized" def parse_pdf(data: bytes) -> str: @@ -325,6 +326,8 @@ class VectorDBWithIndex: weights = ranker.get("params", {}).get("weights", [0.5, 0.5]) reranker_type = RERANKER_TYPE_WEIGHTED reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5} + elif strategy == "normalized": + reranker_type = RERANKER_TYPE_NORMALIZED else: reranker_type = RERANKER_TYPE_RRF k_value = ranker.get("params", {}).get("k", 60.0) diff --git a/pyproject.toml b/pyproject.toml index ecbd8991a..ba5e082c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,8 +25,8 @@ classifiers = [ ] dependencies = [ "aiohttp", - "fastapi>=0.115.0,<1.0", # server - "fire", # for MCP in LLS client + "fastapi>=0.115.0,<1.0", # server + "fire", # for MCP in LLS client "httpx", "huggingface-hub>=0.34.0,<1.0", "jinja2>=3.1.6", @@ -43,12 +43,13 @@ dependencies = [ "tiktoken", "pillow", "h11>=0.16.0", - "python-multipart>=0.0.20", # For fastapi Form - "uvicorn>=0.34.0", # server - "opentelemetry-sdk>=1.30.0", # server + "python-multipart>=0.0.20", # For fastapi Form + "uvicorn>=0.34.0", # server + "opentelemetry-sdk>=1.30.0", # server "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server - "aiosqlite>=0.21.0", # server - for metadata store - "asyncpg", # for metadata store + "aiosqlite>=0.21.0", # server - for metadata store + "asyncpg", # for metadata store + "weaviate-client>=4.16.5", ] [project.optional-dependencies] diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 
c67036eab..0c60acd27 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -22,16 +22,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] for p in vector_io_providers: if p.provider_type in [ - "inline::faiss", - "inline::sqlite-vec", - "inline::milvus", "inline::chromadb", - "remote::pgvector", - "remote::chromadb", - "remote::qdrant", + "inline::faiss", + "inline::milvus", "inline::qdrant", - "remote::weaviate", + "inline::sqlite-vec", + "remote::chromadb", "remote::milvus", + "remote::pgvector", + "remote::qdrant", + "remote::weaviate", ]: return @@ -47,23 +47,25 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode "inline::milvus", "inline::chromadb", "inline::qdrant", - "remote::pgvector", "remote::chromadb", - "remote::weaviate", - "remote::qdrant", "remote::milvus", + "remote::pgvector", + "remote::qdrant", + "remote::weaviate", ], "keyword": [ + "inline::milvus", "inline::sqlite-vec", "remote::milvus", - "inline::milvus", "remote::pgvector", + "remote::weaviate", ], "hybrid": [ - "inline::sqlite-vec", "inline::milvus", + "inline::sqlite-vec", "remote::milvus", "remote::pgvector", + "remote::weaviate", ], } supported_providers = search_mode_support.get(search_mode, []) diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 91bddd037..e97cc0822 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter from 
llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
+from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
+from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter

 EMBEDDING_DIMENSION = 384
 COLLECTION_PREFIX = "test_collection"
 MILVUS_ALIAS = "test_milvus"


-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
 def vector_provider(request):
     return request.param

@@ -446,6 +448,84 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):

     yield adapter
     await adapter.shutdown()
+
+
+@pytest.fixture
+def weaviate_vec_db_path():
+    """Address ("host:port") of the local Weaviate instance used by the Weaviate fixtures."""
+    return "localhost:8080"
+
+
+@pytest.fixture
+async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension):
+    """Yield a WeaviateIndex backed by a freshly created collection on the local Weaviate instance."""
+    import uuid
+
+    import weaviate
+
+    # Connect to the local Weaviate instance named by weaviate_vec_db_path ("host:port")
+    host, _, port = weaviate_vec_db_path.partition(":")
+    client = weaviate.connect_to_local(
+        host=host,
+        port=int(port),
+    )
+
+    collection_name = f"{COLLECTION_PREFIX}_{uuid.uuid4()}"
+    index = WeaviateIndex(client=client, collection_name=collection_name)
+
+    # Create the collection for this test
+    import weaviate.classes as wvc
+
+    from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
+
+    sanitized_name = sanitize_collection_name(collection_name, weaviate_format=True)
+    if not client.collections.exists(sanitized_name):
+        # Use the public create() API rather than hand-building the private _CollectionConfig class
+        client.collections.create(
+            name=sanitized_name,
+            vectorizer_config=wvc.config.Configure.Vectorizer.none(),
+            properties=[
+                wvc.config.Property(
+                    name="chunk_content",
+                    data_type=wvc.config.DataType.TEXT,
+                ),
+            ],
+        )
+
+    yield index
+    # Close the client even if dropping the collection fails
+    try:
+        await index.delete()
+    finally:
+        client.close()
+
+
+@pytest.fixture
+async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
+    """Yield an initialized WeaviateVectorIOAdapter with one registered test vector DB."""
+    config = WeaviateVectorIOConfig(
+        weaviate_cluster_url=weaviate_vec_db_path,
+        weaviate_api_key=None,
+        kvstore=SqliteKVStoreConfig(),
+    )
+    adapter = WeaviateVectorIOAdapter(
+        config=config,
+        inference_api=mock_inference_api,
+        files_api=None,
+    )
+    collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
+    await adapter.initialize()
+    await adapter.register_vector_db(
+        VectorDB(
+            identifier=collection_id,
+            provider_id="test_provider",
+            embedding_model="test_model",
+            embedding_dimension=embedding_dimension,
+        )
+    )
+    adapter.test_collection_id = collection_id
+    yield adapter
+    await adapter.shutdown()


 @pytest.fixture
@@ -457,6 +537,7 @@ def vector_io_adapter(vector_provider, request):
         "chroma": "chroma_vec_adapter",
         "qdrant": "qdrant_vec_adapter",
         "pgvector": "pgvector_vec_adapter",
+        "weaviate": "weaviate_vec_adapter",
     }
     return request.getfixturevalue(vector_provider_dict[vector_provider])

diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py
index ca5f45fa2..c9af52ad8 100644
--- a/tests/unit/providers/vector_io/remote/test_milvus.py
+++ b/tests/unit/providers/vector_io/remote/test_milvus.py
@@ -23,13 +23,13 @@ pymilvus_mock.AnnSearchRequest = MagicMock
 with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
     from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex

-# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
+# This test is a unit test for the MilvusIndex class. This should only contain
 # tests which are specific to this class.
More general (API-level) tests should be placed in # tests/integration/vector_io/ # # How to run this test: # -# pytest tests/unit/providers/vector_io/test_milvus.py \ +# pytest tests/unit/providers/vector_io/remote/test_milvus.py \ # -v -s --tb=short --disable-warnings --asyncio-mode=auto MILVUS_PROVIDER = "milvus" @@ -324,3 +324,6 @@ async def test_query_hybrid_search_default_rrf( call_args = mock_milvus_client.hybrid_search.call_args ranker = call_args[1]["ranker"] assert ranker is not None + + +# TODO: Write tests for the MilvusVectorIOAdapter class. diff --git a/tests/unit/providers/vector_io/remote/test_weaviate.py b/tests/unit/providers/vector_io/remote/test_weaviate.py new file mode 100644 index 000000000..534b3b6b1 --- /dev/null +++ b/tests/unit/providers/vector_io/remote/test_weaviate.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import ANY, MagicMock, patch + +import numpy as np +import pytest + +from llama_stack.apis.vector_io import QueryChunksResponse + +# Mock the Weaviate client +weaviate_mock = MagicMock() + +# Apply the mock before importing WeaviateIndex +with patch.dict("sys.modules", {"weaviate": weaviate_mock}): + from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex + +# This test is a unit test for the WeaviateIndex class. This should only contain +# tests which are specific to this class. 
More general (API-level) tests should be placed in +# tests/integration/vector_io/ +# +# How to run this test: +# +# pytest tests/unit/providers/vector_io/remote/test_weaviate.py \ +# -v -s --tb=short --disable-warnings --asyncio-mode=auto + +WEAVIATE_PROVIDER = "weaviate" + + +@pytest.fixture +async def mock_weaviate_client() -> MagicMock: + """Create a mock Weaviate client with common method behaviors.""" + mock_client = MagicMock() + mock_collection = MagicMock() + + # Mock collection data operations + mock_collection.data.insert_many.return_value = None + mock_collection.data.delete_many.return_value = None + + # Mock collection search operations + mock_collection.query.near_vector.return_value = None + mock_collection.query.bm25.return_value = None + mock_collection.query.hybrid.return_value = None + mock_results = MagicMock() + mock_results.objects = [MagicMock(), MagicMock()] + mock_collection.query.near_vector.return_value = mock_results + + # Mock client collection operations + mock_client.collections.get.return_value = mock_collection + mock_client.collections.exists.return_value = True + mock_client.collections.delete.return_value = None + + # Mock client close operation + mock_client.close.return_value = None + + return mock_client + + +@pytest.fixture +async def weaviate_index(mock_weaviate_client): + """Create a WeaviateIndex with mocked client.""" + index = WeaviateIndex(client=mock_weaviate_client, collection_name="Testcollection") + yield index + # No real cleanup needed since we're using mocks + + +async def test_add_chunks(weaviate_index, sample_chunks, sample_embeddings, mock_weaviate_client): + # Setup: Add chunks first + await weaviate_index.add_chunks(sample_chunks, sample_embeddings) + + # Assert + mock_weaviate_client.collections.get.assert_called_once_with("Testcollection") + mock_weaviate_client.collections.get.return_value.data.insert_many.assert_called_once() + + # Verify the insert call had the right number of chunks + data_objects, _ 
= mock_weaviate_client.collections.get.return_value.data.insert_many.call_args + assert len(data_objects[0]) == len(sample_chunks) + + +async def test_query_chunks_vector( + weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client +): + # Setup: Add chunks first + await weaviate_index.add_chunks(sample_chunks, sample_embeddings) + + # Create mock objects that match Weaviate's response structure + mock_objects = [] + for i, chunk in enumerate(sample_chunks[:2]): # Return first 2 chunks + mock_obj = MagicMock() + mock_obj.properties = {"chunk_content": chunk.model_dump_json()} + mock_obj.metadata.distance = 0.1 + i * 0.1 # Mock distances + mock_objects.append(mock_obj) + + mock_results = MagicMock() + mock_results.objects = mock_objects + mock_weaviate_client.collections.get.return_value.query.near_vector.return_value = mock_results + + # Test vector search + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + response = await weaviate_index.query_vector(query_embedding, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + mock_weaviate_client.collections.get.return_value.query.near_vector.assert_called_once_with( + near_vector=query_embedding.tolist(), + limit=2, + return_metadata=ANY, + ) + + +async def test_query_chunks_keyword_search(weaviate_index, sample_chunks, sample_embeddings, mock_weaviate_client): + await weaviate_index.add_chunks(sample_chunks, sample_embeddings) + + # Find chunks that contain "Sentence 5" + matching_chunks = [chunk for chunk in sample_chunks if "Sentence 5" in chunk.content] + + # Create mock objects that match Weaviate's BM25 response structure + # Return the first 2 matching chunks + mock_objects = [] + for i, chunk in enumerate(matching_chunks[:2]): + mock_obj = MagicMock() + mock_obj.properties = {"chunk_content": chunk.model_dump_json()} + mock_obj.metadata.score = 0.9 - i * 0.1 
+ mock_objects.append(mock_obj) + + mock_results = MagicMock() + mock_results.objects = mock_objects + mock_weaviate_client.collections.get.return_value.query.bm25.return_value = mock_results + + # Test keyword search + query_string = "Sentence 5" + response = await weaviate_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + # Verify that the returned chunks contain the search term + assert all("Sentence 5" in chunk.content for chunk in response.chunks) + mock_weaviate_client.collections.get.return_value.query.bm25.assert_called_once_with( + query=query_string, + limit=2, + return_metadata=ANY, + ) + + +async def test_delete_collection(weaviate_index, mock_weaviate_client): + # Test collection deletion (when chunk_ids is None, it deletes the entire collection) + mock_weaviate_client.collections.exists.return_value = True + + await weaviate_index.delete() + + mock_weaviate_client.collections.exists.assert_called_once_with("Testcollection") + mock_weaviate_client.collections.delete.assert_called_once_with("Testcollection") + + +async def test_delete_chunks(weaviate_index, mock_weaviate_client): + # Test deleting specific chunks using ChunkForDeletion objects + from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion + + chunks_for_deletion = [ + ChunkForDeletion(chunk_id="chunk-1", document_id="doc-1"), + ChunkForDeletion(chunk_id="chunk-2", document_id="doc-1"), + ChunkForDeletion(chunk_id="chunk-3", document_id="doc-2"), + ] + + await weaviate_index.delete_chunks(chunks_for_deletion) + + mock_weaviate_client.collections.get.assert_called_once_with("Testcollection") + mock_weaviate_client.collections.get.return_value.data.delete_many.assert_called_once() + + +async def test_query_hybrid_rrf( + weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client +): + # Test 
hybrid search with RRF reranking + from weaviate.classes.query import HybridFusion + + from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF + + await weaviate_index.add_chunks(sample_chunks, sample_embeddings) + + # Find chunks that contain "Sentence 5" + matching_chunks = [chunk for chunk in sample_chunks if "Sentence 5" in chunk.content] + + # Create mock objects for hybrid search response + mock_objects = [] + for i, chunk in enumerate(matching_chunks[:2]): + mock_obj = MagicMock() + mock_obj.properties = {"chunk_content": chunk.model_dump_json()} + mock_obj.metadata.score = 0.85 + i * 0.05 + mock_objects.append(mock_obj) + + mock_results = MagicMock() + mock_results.objects = mock_objects + mock_weaviate_client.collections.get.return_value.query.hybrid.return_value = mock_results + + # Test hybrid search with RRF + query_string = "Sentence 5" + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + response = await weaviate_index.query_hybrid( + embedding=query_embedding, query_string=query_string, k=2, score_threshold=0.0, reranker_type=RERANKER_TYPE_RRF + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + assert all("Sentence 5" in chunk.content for chunk in response.chunks) + + # Verify the hybrid method was called with correct parameters + mock_weaviate_client.collections.get.return_value.query.hybrid.assert_called_once_with( + query=query_string, + alpha=0.5, + vector=query_embedding.tolist(), + limit=2, + fusion_type=HybridFusion.RANKED, + return_metadata=ANY, + ) + + +async def test_query_hybrid_normalized( + weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client +): + from weaviate.classes.query import HybridFusion + + await weaviate_index.add_chunks(sample_chunks, sample_embeddings) + + # Find chunks that contain "Sentence 3" for different results + matching_chunks = [chunk for chunk in 
sample_chunks if "Sentence 3" in chunk.content] + + # Create mock objects for hybrid search response + mock_objects = [] + for i, chunk in enumerate(matching_chunks[:2]): + mock_obj = MagicMock() + mock_obj.properties = {"chunk_content": chunk.model_dump_json()} + mock_obj.metadata.score = 0.75 + i * 0.1 # Mock hybrid scores + mock_objects.append(mock_obj) + + mock_results = MagicMock() + mock_results.objects = mock_objects + mock_weaviate_client.collections.get.return_value.query.hybrid.return_value = mock_results + + # Test hybrid search with normalized reranking + query_string = "Sentence 3" + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + response = await weaviate_index.query_hybrid( + embedding=query_embedding, query_string=query_string, k=2, score_threshold=0.0, reranker_type="normalized" + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + assert all("Sentence 3" in chunk.content for chunk in response.chunks) + + # Verify the hybrid method was called with correct parameters + mock_weaviate_client.collections.get.return_value.query.hybrid.assert_called_once_with( + query=query_string, + alpha=0.5, + vector=query_embedding.tolist(), + limit=2, + fusion_type=HybridFusion.RELATIVE_SCORE, + return_metadata=ANY, + ) + + +# TODO: Write tests for the WeaviateVectorIOAdapter class. diff --git a/uv.lock b/uv.lock index 0833a9d77..bfa6d7d04 100644 --- a/uv.lock +++ b/uv.lock @@ -1777,6 +1777,7 @@ dependencies = [ { name = "termcolor" }, { name = "tiktoken" }, { name = "uvicorn" }, + { name = "weaviate-client" }, ] [package.optional-dependencies] @@ -1904,6 +1905,7 @@ requires-dist = [ { name = "termcolor" }, { name = "tiktoken" }, { name = "uvicorn", specifier = ">=0.34.0" }, + { name = "weaviate-client", specifier = ">=4.16.5" }, ] provides-extras = ["ui"]