feat: implement keyword and hybrid search for Weaviate provider

This commit is contained in:
ChristianZaccaria 2025-08-27 12:24:38 +01:00
parent a1301911e4
commit 4541b517c8
8 changed files with 476 additions and 25 deletions

View file

@ -10,7 +10,7 @@ import weaviate
import weaviate.classes as wvc import weaviate.classes as wvc
from numpy.typing import NDArray from numpy.typing import NDArray
from weaviate.classes.init import Auth from weaviate.classes.init import Auth
from weaviate.classes.query import Filter from weaviate.classes.query import Filter, HybridFusion
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
OpenAIVectorStoreMixin, OpenAIVectorStoreMixin,
) )
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
ChunkForDeletion, ChunkForDeletion,
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
@ -88,6 +89,9 @@ class WeaviateIndex(EmbeddingIndex):
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids)) collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
log.info(
f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
)
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
collection = self.client.collections.get(sanitized_collection_name) collection = self.client.collections.get(sanitized_collection_name)
@ -115,6 +119,7 @@ class WeaviateIndex(EmbeddingIndex):
chunks.append(chunk) chunks.append(chunk)
scores.append(score) scores.append(score)
log.info(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores) return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: list[str] | None = None) -> None: async def delete(self, chunk_ids: list[str] | None = None) -> None:
@ -136,7 +141,46 @@ class WeaviateIndex(EmbeddingIndex):
k: int, k: int,
score_threshold: float, score_threshold: float,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Weaviate") """
Performs BM25-based keyword search using Weaviate's built-in full-text search.
Args:
query_string: The text query for keyword search
k: Limit of number of results to return
score_threshold: Minimum similarity score threshold
Returns:
QueryChunksResponse with combined results
"""
log.info(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
collection = self.client.collections.get(sanitized_collection_name)
# Perform BM25 keyword search on chunk_content field
results = collection.query.bm25(
query=query_string,
limit=k,
return_metadata=wvc.query.MetadataQuery(score=True),
)
chunks = []
scores = []
for doc in results.objects:
chunk_json = doc.properties["chunk_content"]
try:
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
log.exception(f"Failed to parse document: {chunk_json}")
continue
score = doc.metadata.score if doc.metadata.score is not None else 0.0
if score < score_threshold:
continue
chunks.append(chunk)
scores.append(score)
log.info(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid( async def query_hybrid(
self, self,
@ -147,7 +191,62 @@ class WeaviateIndex(EmbeddingIndex):
reranker_type: str, reranker_type: str,
reranker_params: dict[str, Any] | None = None, reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Weaviate") """
Hybrid search combining vector similarity and keyword search using Weaviate's native hybrid search.
Args:
embedding: The query embedding vector
query_string: The text query for keyword search
k: Limit of number of results to return
score_threshold: Minimum similarity score threshold
reranker_type: Type of reranker to use ("rrf" or "normalized")
reranker_params: Parameters for the reranker
Returns:
QueryChunksResponse with combined results
"""
log.info(
f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}"
)
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
collection = self.client.collections.get(sanitized_collection_name)
# Ranked (RRF) reranker fusion type
if reranker_type == RERANKER_TYPE_RRF:
rerank = HybridFusion.RANKED
# Relative score (Normalized) reranker fusion type
else:
rerank = HybridFusion.RELATIVE_SCORE
# Perform hybrid search using Weaviate's native hybrid search
results = collection.query.hybrid(
query=query_string,
alpha=0.5, # Range <0, 1>, where 0.5 will equally favor vector and keyword search
vector=embedding.tolist(),
limit=k,
fusion_type=rerank,
return_metadata=wvc.query.MetadataQuery(score=True),
)
chunks = []
scores = []
for doc in results.objects:
chunk_json = doc.properties["chunk_content"]
try:
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
log.exception(f"Failed to parse document: {chunk_json}")
continue
score = doc.metadata.score if doc.metadata.score is not None else 0.0
if score < score_threshold:
continue
log.info(f"Document {chunk.metadata.get('document_id')} has score {score}")
chunks.append(chunk)
scores.append(score)
log.info(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
return QueryChunksResponse(chunks=chunks, scores=scores)
class WeaviateVectorIOAdapter( class WeaviateVectorIOAdapter(

View file

@ -50,6 +50,7 @@ class ChunkForDeletion(BaseModel):
# Constants for reranker types # Constants for reranker types
RERANKER_TYPE_RRF = "rrf" RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted" RERANKER_TYPE_WEIGHTED = "weighted"
RERANKER_TYPE_NORMALIZED = "normalized"
def parse_pdf(data: bytes) -> str: def parse_pdf(data: bytes) -> str:
@ -325,6 +326,8 @@ class VectorDBWithIndex:
weights = ranker.get("params", {}).get("weights", [0.5, 0.5]) weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
reranker_type = RERANKER_TYPE_WEIGHTED reranker_type = RERANKER_TYPE_WEIGHTED
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5} reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
elif strategy == "normalized":
reranker_type = RERANKER_TYPE_NORMALIZED
else: else:
reranker_type = RERANKER_TYPE_RRF reranker_type = RERANKER_TYPE_RRF
k_value = ranker.get("params", {}).get("k", 60.0) k_value = ranker.get("params", {}).get("k", 60.0)

View file

@ -25,8 +25,8 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"aiohttp", "aiohttp",
"fastapi>=0.115.0,<1.0", # server "fastapi>=0.115.0,<1.0", # server
"fire", # for MCP in LLS client "fire", # for MCP in LLS client
"httpx", "httpx",
"huggingface-hub>=0.34.0,<1.0", "huggingface-hub>=0.34.0,<1.0",
"jinja2>=3.1.6", "jinja2>=3.1.6",
@ -43,12 +43,13 @@ dependencies = [
"tiktoken", "tiktoken",
"pillow", "pillow",
"h11>=0.16.0", "h11>=0.16.0",
"python-multipart>=0.0.20", # For fastapi Form "python-multipart>=0.0.20", # For fastapi Form
"uvicorn>=0.34.0", # server "uvicorn>=0.34.0", # server
"opentelemetry-sdk>=1.30.0", # server "opentelemetry-sdk>=1.30.0", # server
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store "aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store "asyncpg", # for metadata store
"weaviate-client>=4.16.5",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View file

@ -22,16 +22,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers: for p in vector_io_providers:
if p.provider_type in [ if p.provider_type in [
"inline::faiss",
"inline::sqlite-vec",
"inline::milvus",
"inline::chromadb", "inline::chromadb",
"remote::pgvector", "inline::faiss",
"remote::chromadb", "inline::milvus",
"remote::qdrant",
"inline::qdrant", "inline::qdrant",
"remote::weaviate", "inline::sqlite-vec",
"remote::chromadb",
"remote::milvus", "remote::milvus",
"remote::pgvector",
"remote::qdrant",
"remote::weaviate",
]: ]:
return return
@ -47,23 +47,25 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"inline::milvus", "inline::milvus",
"inline::chromadb", "inline::chromadb",
"inline::qdrant", "inline::qdrant",
"remote::pgvector",
"remote::chromadb", "remote::chromadb",
"remote::weaviate",
"remote::qdrant",
"remote::milvus", "remote::milvus",
"remote::pgvector",
"remote::qdrant",
"remote::weaviate",
], ],
"keyword": [ "keyword": [
"inline::milvus",
"inline::sqlite-vec", "inline::sqlite-vec",
"remote::milvus", "remote::milvus",
"inline::milvus",
"remote::pgvector", "remote::pgvector",
"remote::weaviate",
], ],
"hybrid": [ "hybrid": [
"inline::sqlite-vec",
"inline::milvus", "inline::milvus",
"inline::sqlite-vec",
"remote::milvus", "remote::milvus",
"remote::pgvector", "remote::pgvector",
"remote::weaviate",
], ],
} }
supported_providers = search_mode_support.get(search_mode, []) supported_providers = search_mode_support.get(search_mode, [])

View file

@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
EMBEDDING_DIMENSION = 384 EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection" COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus" MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"]) @pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
def vector_provider(request): def vector_provider(request):
return request.param return request.param
@ -446,6 +448,75 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
yield adapter yield adapter
await adapter.shutdown() await adapter.shutdown()
def weaviate_vec_db_path():
return "localhost:8080"
@pytest.fixture
async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension):
import uuid
import weaviate
# Connect to local Weaviate instance
client = weaviate.connect_to_local(
host="localhost",
port=8080,
)
collection_name = f"{COLLECTION_PREFIX}_{uuid.uuid4()}"
index = WeaviateIndex(client=client, collection_name=collection_name)
# Create the collection for this test
import weaviate.classes as wvc
from weaviate.collections.classes.config import _CollectionConfig
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
sanitized_name = sanitize_collection_name(collection_name, weaviate_format=True)
collection_config = _CollectionConfig(
name=sanitized_name,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
],
)
if not client.collections.exists(sanitized_name):
client.collections.create_from_config(collection_config)
yield index
await index.delete()
client.close()
@pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
config = WeaviateVectorIOConfig(
weaviate_cluster_url=weaviate_vec_db_path,
weaviate_api_key=None,
kvstore=SqliteKVStoreConfig(),
)
adapter = WeaviateVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
@pytest.fixture @pytest.fixture
@ -457,6 +528,7 @@ def vector_io_adapter(vector_provider, request):
"chroma": "chroma_vec_adapter", "chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter", "qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter", "pgvector": "pgvector_vec_adapter",
"weaviate": "weaviate_vec_adapter",
} }
return request.getfixturevalue(vector_provider_dict[vector_provider]) return request.getfixturevalue(vector_provider_dict[vector_provider])

View file

@ -23,13 +23,13 @@ pymilvus_mock.AnnSearchRequest = MagicMock
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain # This test is a unit test for the MilvusIndex class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in # tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/ # tests/integration/vector_io/
# #
# How to run this test: # How to run this test:
# #
# pytest tests/unit/providers/vector_io/test_milvus.py \ # pytest tests/unit/providers/vector_io/remote/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto # -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus" MILVUS_PROVIDER = "milvus"
@ -324,3 +324,6 @@ async def test_query_hybrid_search_default_rrf(
call_args = mock_milvus_client.hybrid_search.call_args call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"] ranker = call_args[1]["ranker"]
assert ranker is not None assert ranker is not None
# TODO: Write tests for the MilvusVectorIOAdapter class.

View file

@ -0,0 +1,269 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import ANY, MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the Weaviate client
weaviate_mock = MagicMock()
# Apply the mock before importing WeaviateIndex
with patch.dict("sys.modules", {"weaviate": weaviate_mock}):
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex
# This test is a unit test for the WeaviateIndex class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/remote/test_weaviate.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
WEAVIATE_PROVIDER = "weaviate"
@pytest.fixture
async def mock_weaviate_client() -> MagicMock:
"""Create a mock Weaviate client with common method behaviors."""
mock_client = MagicMock()
mock_collection = MagicMock()
# Mock collection data operations
mock_collection.data.insert_many.return_value = None
mock_collection.data.delete_many.return_value = None
# Mock collection search operations
mock_collection.query.near_vector.return_value = None
mock_collection.query.bm25.return_value = None
mock_collection.query.hybrid.return_value = None
mock_results = MagicMock()
mock_results.objects = [MagicMock(), MagicMock()]
mock_collection.query.near_vector.return_value = mock_results
# Mock client collection operations
mock_client.collections.get.return_value = mock_collection
mock_client.collections.exists.return_value = True
mock_client.collections.delete.return_value = None
# Mock client close operation
mock_client.close.return_value = None
return mock_client
@pytest.fixture
async def weaviate_index(mock_weaviate_client):
"""Create a WeaviateIndex with mocked client."""
index = WeaviateIndex(client=mock_weaviate_client, collection_name="Testcollection")
yield index
# No real cleanup needed since we're using mocks
async def test_add_chunks(weaviate_index, sample_chunks, sample_embeddings, mock_weaviate_client):
# Setup: Add chunks first
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
# Assert
mock_weaviate_client.collections.get.assert_called_once_with("Testcollection")
mock_weaviate_client.collections.get.return_value.data.insert_many.assert_called_once()
# Verify the insert call had the right number of chunks
data_objects, _ = mock_weaviate_client.collections.get.return_value.data.insert_many.call_args
assert len(data_objects[0]) == len(sample_chunks)
async def test_query_chunks_vector(
weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client
):
# Setup: Add chunks first
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
# Create mock objects that match Weaviate's response structure
mock_objects = []
for i, chunk in enumerate(sample_chunks[:2]): # Return first 2 chunks
mock_obj = MagicMock()
mock_obj.properties = {"chunk_content": chunk.model_dump_json()}
mock_obj.metadata.distance = 0.1 + i * 0.1 # Mock distances
mock_objects.append(mock_obj)
mock_results = MagicMock()
mock_results.objects = mock_objects
mock_weaviate_client.collections.get.return_value.query.near_vector.return_value = mock_results
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await weaviate_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
mock_weaviate_client.collections.get.return_value.query.near_vector.assert_called_once_with(
near_vector=query_embedding.tolist(),
limit=2,
return_metadata=ANY,
)
async def test_query_chunks_keyword_search(weaviate_index, sample_chunks, sample_embeddings, mock_weaviate_client):
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
# Find chunks that contain "Sentence 5"
matching_chunks = [chunk for chunk in sample_chunks if "Sentence 5" in chunk.content]
# Create mock objects that match Weaviate's BM25 response structure
# Return the first 2 matching chunks
mock_objects = []
for i, chunk in enumerate(matching_chunks[:2]):
mock_obj = MagicMock()
mock_obj.properties = {"chunk_content": chunk.model_dump_json()}
mock_obj.metadata.score = 0.9 - i * 0.1
mock_objects.append(mock_obj)
mock_results = MagicMock()
mock_results.objects = mock_objects
mock_weaviate_client.collections.get.return_value.query.bm25.return_value = mock_results
# Test keyword search
query_string = "Sentence 5"
response = await weaviate_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify that the returned chunks contain the search term
assert all("Sentence 5" in chunk.content for chunk in response.chunks)
mock_weaviate_client.collections.get.return_value.query.bm25.assert_called_once_with(
query=query_string,
limit=2,
return_metadata=ANY,
)
async def test_delete_collection(weaviate_index, mock_weaviate_client):
# Test collection deletion (when chunk_ids is None, it deletes the entire collection)
mock_weaviate_client.collections.exists.return_value = True
await weaviate_index.delete()
mock_weaviate_client.collections.exists.assert_called_once_with("Testcollection")
mock_weaviate_client.collections.delete.assert_called_once_with("Testcollection")
async def test_delete_chunks(weaviate_index, mock_weaviate_client):
# Test deleting specific chunks using ChunkForDeletion objects
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion
chunks_for_deletion = [
ChunkForDeletion(chunk_id="chunk-1", document_id="doc-1"),
ChunkForDeletion(chunk_id="chunk-2", document_id="doc-1"),
ChunkForDeletion(chunk_id="chunk-3", document_id="doc-2"),
]
await weaviate_index.delete_chunks(chunks_for_deletion)
mock_weaviate_client.collections.get.assert_called_once_with("Testcollection")
mock_weaviate_client.collections.get.return_value.data.delete_many.assert_called_once()
async def test_query_hybrid_rrf(
weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client
):
# Test hybrid search with RRF reranking
from weaviate.classes.query import HybridFusion
from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
# Find chunks that contain "Sentence 5"
matching_chunks = [chunk for chunk in sample_chunks if "Sentence 5" in chunk.content]
# Create mock objects for hybrid search response
mock_objects = []
for i, chunk in enumerate(matching_chunks[:2]):
mock_obj = MagicMock()
mock_obj.properties = {"chunk_content": chunk.model_dump_json()}
mock_obj.metadata.score = 0.85 + i * 0.05
mock_objects.append(mock_obj)
mock_results = MagicMock()
mock_results.objects = mock_objects
mock_weaviate_client.collections.get.return_value.query.hybrid.return_value = mock_results
# Test hybrid search with RRF
query_string = "Sentence 5"
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await weaviate_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=2, score_threshold=0.0, reranker_type=RERANKER_TYPE_RRF
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
assert all("Sentence 5" in chunk.content for chunk in response.chunks)
# Verify the hybrid method was called with correct parameters
mock_weaviate_client.collections.get.return_value.query.hybrid.assert_called_once_with(
query=query_string,
alpha=0.5,
vector=query_embedding.tolist(),
limit=2,
fusion_type=HybridFusion.RANKED,
return_metadata=ANY,
)
async def test_query_hybrid_normalized(
weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client
):
from weaviate.classes.query import HybridFusion
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
# Find chunks that contain "Sentence 3" for different results
matching_chunks = [chunk for chunk in sample_chunks if "Sentence 3" in chunk.content]
# Create mock objects for hybrid search response
mock_objects = []
for i, chunk in enumerate(matching_chunks[:2]):
mock_obj = MagicMock()
mock_obj.properties = {"chunk_content": chunk.model_dump_json()}
mock_obj.metadata.score = 0.75 + i * 0.1 # Mock hybrid scores
mock_objects.append(mock_obj)
mock_results = MagicMock()
mock_results.objects = mock_objects
mock_weaviate_client.collections.get.return_value.query.hybrid.return_value = mock_results
# Test hybrid search with normalized reranking
query_string = "Sentence 3"
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await weaviate_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=2, score_threshold=0.0, reranker_type="normalized"
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
assert all("Sentence 3" in chunk.content for chunk in response.chunks)
# Verify the hybrid method was called with correct parameters
mock_weaviate_client.collections.get.return_value.query.hybrid.assert_called_once_with(
query=query_string,
alpha=0.5,
vector=query_embedding.tolist(),
limit=2,
fusion_type=HybridFusion.RELATIVE_SCORE,
return_metadata=ANY,
)
# TODO: Write tests for the WeaviateVectorIOAdapter class.

2
uv.lock generated
View file

@ -1777,6 +1777,7 @@ dependencies = [
{ name = "termcolor" }, { name = "termcolor" },
{ name = "tiktoken" }, { name = "tiktoken" },
{ name = "uvicorn" }, { name = "uvicorn" },
{ name = "weaviate-client" },
] ]
[package.optional-dependencies] [package.optional-dependencies]
@ -1904,6 +1905,7 @@ requires-dist = [
{ name = "termcolor" }, { name = "termcolor" },
{ name = "tiktoken" }, { name = "tiktoken" },
{ name = "uvicorn", specifier = ">=0.34.0" }, { name = "uvicorn", specifier = ">=0.34.0" },
{ name = "weaviate-client", specifier = ">=4.16.5" },
] ]
provides-extras = ["ui"] provides-extras = ["ui"]