From bcdbb53be3137e47639abd0c9e94686bb5af498d Mon Sep 17 00:00:00 2001 From: Christian Zaccaria <73656840+ChristianZaccaria@users.noreply.github.com> Date: Fri, 3 Oct 2025 09:22:30 +0100 Subject: [PATCH] feat: implement keyword and hybrid search for Weaviate provider (#3264) # What does this PR do? - This PR implements keyword and hybrid search for Weaviate DB based on its inbuilt functions. - Added fixtures to conftest.py for Weaviate. - Enabled integration tests for remote Weaviate on all 3 search modes. Closes #3010 ## Test Plan Unit tests and integration tests should pass on this PR. --- llama_stack/providers/registry/vector_io.py | 2 +- .../remote/vector_io/weaviate/weaviate.py | 187 ++++++++++++++---- .../providers/utils/memory/vector_store.py | 3 + .../vector_io/test_openai_vector_stores.py | 26 +-- tests/integration/vector_io/test_vector_io.py | 2 +- tests/unit/providers/vector_io/conftest.py | 70 ++++++- 6 files changed, 242 insertions(+), 48 deletions(-) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 9816838e7..ebab7aaf9 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -500,7 +500,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de api=Api.vector_io, adapter_type="weaviate", provider_type="remote::weaviate", - pip_packages=["weaviate-client"], + pip_packages=["weaviate-client>=4.16.5"], module="llama_stack.providers.remote.vector_io.weaviate", config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 59b6bf124..02d132106 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -10,7 +10,7 @@ import weaviate import weaviate.classes as wvc from numpy.typing import NDArray from weaviate.classes.init import Auth -from weaviate.classes.query import Filter +from weaviate.classes.query import Filter, HybridFusion from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.errors import VectorStoreNotFoundError @@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import ( OpenAIVectorStoreMixin, ) from llama_stack.providers.utils.memory.vector_store import ( + RERANKER_TYPE_RRF, ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, @@ -47,7 +48,7 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class WeaviateIndex(EmbeddingIndex): def __init__( self, - client: weaviate.Client, + client: weaviate.WeaviateClient, collection_name: str, kvstore: KVStore | None = None, ): @@ -64,14 +65,14 @@ class WeaviateIndex(EmbeddingIndex): ) data_objects = [] - for i, chunk in enumerate(chunks): + for chunk, embedding in zip(chunks, embeddings, strict=False): data_objects.append( wvc.data.DataObject( properties={ "chunk_id": chunk.chunk_id, "chunk_content": chunk.model_dump_json(), }, - vector=embeddings[i].tolist(), + vector=embedding.tolist(), ) ) @@ -88,14 +89,30 @@ class WeaviateIndex(EmbeddingIndex): collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids)) async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + """ + Performs vector search using Weaviate's built-in vector search. + Args: + embedding: The query embedding vector + k: Limit of number of results to return + score_threshold: Minimum similarity score threshold + Returns: + QueryChunksResponse with chunks and scores. + """ + log.debug( + f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}" + ) sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) collection = self.client.collections.get(sanitized_collection_name) - results = collection.query.near_vector( - near_vector=embedding.tolist(), - limit=k, - return_metadata=wvc.query.MetadataQuery(distance=True), - ) + try: + results = collection.query.near_vector( + near_vector=embedding.tolist(), + limit=k, + return_metadata=wvc.query.MetadataQuery(distance=True), + ) + except Exception as e: + log.error(f"Weaviate client vector search failed: {e}") + raise chunks = [] scores = [] @@ -108,13 +125,17 @@ class WeaviateIndex(EmbeddingIndex): log.exception(f"Failed to parse document: {chunk_json}") continue - score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf") + if doc.metadata.distance is None: + continue + # Convert cosine distance ∈ [0,2] -> normalized cosine similarity ∈ [0,1] + score = 1.0 - (float(doc.metadata.distance) / 2.0) if score < score_threshold: continue chunks.append(chunk) scores.append(score) + log.debug(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}") return QueryChunksResponse(chunks=chunks, scores=scores) async def delete(self, chunk_ids: list[str] | None = None) -> None: @@ -136,7 +157,50 @@ class WeaviateIndex(EmbeddingIndex): k: int, score_threshold: float, ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in Weaviate") + """ + Performs BM25-based keyword search using Weaviate's built-in full-text search. + Args: + query_string: The text query for keyword search + k: Limit of number of results to return + score_threshold: Minimum similarity score threshold + Returns: + QueryChunksResponse with chunks and scores + """ + log.debug(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}") + sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) + collection = self.client.collections.get(sanitized_collection_name) + + # Perform BM25 keyword search on chunk_content field + try: + results = collection.query.bm25( + query=query_string, + limit=k, + return_metadata=wvc.query.MetadataQuery(score=True), + ) + except Exception as e: + log.error(f"Weaviate client keyword search failed: {e}") + raise + + chunks = [] + scores = [] + for doc in results.objects: + chunk_json = doc.properties["chunk_content"] + try: + chunk_dict = json.loads(chunk_json) + chunk = Chunk(**chunk_dict) + except Exception: + log.exception(f"Failed to parse document: {chunk_json}") + continue + + score = doc.metadata.score if doc.metadata.score is not None else 0.0 + if score < score_threshold: + continue + + chunks.append(chunk) + scores.append(score) + + log.debug(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.") + return QueryChunksResponse(chunks=chunks, scores=scores) async def query_hybrid( self, @@ -147,7 +211,65 @@ class WeaviateIndex(EmbeddingIndex): reranker_type: str, reranker_params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - raise NotImplementedError("Hybrid search is not supported in Weaviate") + """ + Hybrid search combining vector similarity and keyword search using Weaviate's native hybrid search. + Args: + embedding: The query embedding vector + query_string: The text query for keyword search + k: Limit of number of results to return + score_threshold: Minimum similarity score threshold + reranker_type: Type of reranker to use ("rrf" or "normalized") + reranker_params: Parameters for the reranker + Returns: + QueryChunksResponse with combined results + """ + log.debug( + f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}" + ) + sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) + collection = self.client.collections.get(sanitized_collection_name) + + # Ranked (RRF) reranker fusion type + if reranker_type == RERANKER_TYPE_RRF: + rerank = HybridFusion.RANKED + # Relative score (Normalized) reranker fusion type + else: + rerank = HybridFusion.RELATIVE_SCORE + + # Perform hybrid search using Weaviate's native hybrid search + try: + results = collection.query.hybrid( + query=query_string, + alpha=0.5, # Range <0, 1>, where 0.5 will equally favor vector and keyword search + vector=embedding.tolist(), + limit=k, + fusion_type=rerank, + return_metadata=wvc.query.MetadataQuery(score=True), + ) + except Exception as e: + log.error(f"Weaviate client hybrid search failed: {e}") + raise + + chunks = [] + scores = [] + for doc in results.objects: + chunk_json = doc.properties["chunk_content"] + try: + chunk_dict = json.loads(chunk_json) + chunk = Chunk(**chunk_dict) + except Exception: + log.exception(f"Failed to parse document: {chunk_json}") + continue + + score = doc.metadata.score if doc.metadata.score is not None else 0.0 + if score < score_threshold: + continue + + chunks.append(chunk) + scores.append(score) + + log.debug(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}") + return QueryChunksResponse(chunks=chunks, scores=scores) class WeaviateVectorIOAdapter( @@ -172,9 +294,9 @@ class WeaviateVectorIOAdapter( self.openai_vector_stores: dict[str, dict[str, Any]] = {} self.metadata_collection_name = "openai_vector_stores_metadata" - def _get_client(self) -> weaviate.Client: + def _get_client(self) -> weaviate.WeaviateClient: if "localhost" in self.config.weaviate_cluster_url: - log.info("using Weaviate locally in container") + log.info("Using Weaviate locally in container") host, port = self.config.weaviate_cluster_url.split(":") key = "local_test" client = weaviate.connect_to_local( @@ -247,7 +369,7 @@ class WeaviateVectorIOAdapter( ], ) - self.cache[sanitized_collection_name] = VectorDBWithIndex( + self.cache[vector_db.identifier] = VectorDBWithIndex( vector_db, WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api, @@ -256,32 +378,34 @@ class WeaviateVectorIOAdapter( async def unregister_vector_db(self, vector_db_id: str) -> None: client = self._get_client() sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) - if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False: - log.warning(f"Vector DB {sanitized_collection_name} not found") + if vector_db_id not in self.cache or client.collections.exists(sanitized_collection_name) is False: return client.collections.delete(sanitized_collection_name) - await self.cache[sanitized_collection_name].index.delete() - del self.cache[sanitized_collection_name] + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: - sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) - if sanitized_collection_name in self.cache: - return self.cache[sanitized_collection_name] + if vector_db_id in self.cache: + return self.cache[vector_db_id] - vector_db = await self.vector_db_store.get_vector_db(sanitized_collection_name) + if self.vector_db_store is None: + raise VectorStoreNotFoundError(vector_db_id) + + vector_db = await self.vector_db_store.get_vector_db(vector_db_id) if not vector_db: raise VectorStoreNotFoundError(vector_db_id) client = self._get_client() - if not client.collections.exists(vector_db.identifier): + sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True) + if not client.collections.exists(sanitized_collection_name): raise ValueError(f"Collection with name `{sanitized_collection_name}` not found") index = VectorDBWithIndex( vector_db=vector_db, - index=WeaviateIndex(client=client, collection_name=sanitized_collection_name), + index=WeaviateIndex(client=client, collection_name=vector_db.identifier), inference_api=self.inference_api, ) - self.cache[sanitized_collection_name] = index + self.cache[vector_db_id] = index return index async def insert_chunks( @@ -290,8 +414,7 @@ class WeaviateVectorIOAdapter( chunks: list[Chunk], ttl_seconds: int | None = None, ) -> None: - sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) - index = await self._get_and_cache_vector_db_index(sanitized_collection_name) + index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise VectorStoreNotFoundError(vector_db_id) @@ -303,17 +426,15 @@ class WeaviateVectorIOAdapter( query: InterleavedContent, params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) - index = await self._get_and_cache_vector_db_index(sanitized_collection_name) + index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: - sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True) - index = await self._get_and_cache_vector_db_index(sanitized_collection_name) + index = await self._get_and_cache_vector_db_index(store_id) if not index: - raise ValueError(f"Vector DB {sanitized_collection_name} not found") + raise ValueError(f"Vector DB {store_id} not found") await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index aaa470970..857fbe910 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -50,6 +50,7 @@ class ChunkForDeletion(BaseModel): # Constants for reranker types RERANKER_TYPE_RRF = "rrf" RERANKER_TYPE_WEIGHTED = "weighted" +RERANKER_TYPE_NORMALIZED = "normalized" def parse_pdf(data: bytes) -> str: @@ -325,6 +326,8 @@ class VectorDBWithIndex: weights = ranker.get("params", {}).get("weights", [0.5, 0.5]) reranker_type = RERANKER_TYPE_WEIGHTED reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5} + elif strategy == "normalized": + reranker_type = RERANKER_TYPE_NORMALIZED else: reranker_type = RERANKER_TYPE_RRF k_value = ranker.get("params", {}).get("k", 60.0) diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index c67036eab..0c60acd27 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -22,16 +22,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] for p in vector_io_providers: if p.provider_type in [ - "inline::faiss", - "inline::sqlite-vec", - "inline::milvus", "inline::chromadb", - "remote::pgvector", - "remote::chromadb", - "remote::qdrant", + "inline::faiss", + "inline::milvus", "inline::qdrant", - "remote::weaviate", + "inline::sqlite-vec", + "remote::chromadb", "remote::milvus", + "remote::pgvector", + "remote::qdrant", + "remote::weaviate", ]: return @@ -47,23 +47,25 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode "inline::milvus", "inline::chromadb", "inline::qdrant", - "remote::pgvector", "remote::chromadb", - "remote::weaviate", - "remote::qdrant", "remote::milvus", + "remote::pgvector", + "remote::qdrant", + "remote::weaviate", ], "keyword": [ + "inline::milvus", "inline::sqlite-vec", "remote::milvus", - "inline::milvus", "remote::pgvector", + "remote::weaviate", ], "hybrid": [ - "inline::sqlite-vec", "inline::milvus", + "inline::sqlite-vec", "remote::milvus", "remote::pgvector", + "remote::weaviate", ], } supported_providers = search_mode_support.get(search_mode, []) diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index 979eff6bb..7bfe31dd6 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -138,8 +138,8 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension): vector_io_provider_params_dict = { "inline::milvus": {"score_threshold": -1.0}, - "remote::qdrant": {"score_threshold": -1.0}, "inline::qdrant": {"score_threshold": -1.0}, + "remote::qdrant": {"score_threshold": -1.0}, } vector_db_name = "test_precomputed_embeddings_db" register_response = client_with_empty_registry.vector_dbs.register( diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 91bddd037..70ace695e 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter +from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig +from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter EMBEDDING_DIMENSION = 384 COLLECTION_PREFIX = "test_collection" MILVUS_ALIAS = "test_milvus" -@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"]) +@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"]) def vector_provider(request): return request.param @@ -448,6 +450,71 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension): await adapter.shutdown() +@pytest.fixture(scope="session") +def weaviate_vec_db_path(tmp_path_factory): + db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db") + return db_path + + +@pytest.fixture +async def weaviate_vec_index(weaviate_vec_db_path): + import pytest_socket + import weaviate + + pytest_socket.enable_socket() + client = weaviate.connect_to_embedded( + hostname="localhost", + port=8080, + grpc_port=50051, + persistence_data_path=weaviate_vec_db_path, + ) + index = WeaviateIndex(client=client, collection_name="Testcollection") + await index.initialize() + yield index + await index.delete() + client.close() + + +@pytest.fixture +async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension): + import pytest_socket + import weaviate + + pytest_socket.enable_socket() + + client = weaviate.connect_to_embedded( + hostname="localhost", + port=8080, + grpc_port=50051, + persistence_data_path=weaviate_vec_db_path, + ) + + config = WeaviateVectorIOConfig( + weaviate_cluster_url="localhost:8080", + weaviate_api_key=None, + kvstore=SqliteKVStoreConfig(), + ) + adapter = WeaviateVectorIOAdapter( + config=config, + inference_api=mock_inference_api, + files_api=None, + ) + collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}" + await adapter.initialize() + await adapter.register_vector_db( + VectorDB( + identifier=collection_id, + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=embedding_dimension, + ) + ) + adapter.test_collection_id = collection_id + yield adapter + await adapter.shutdown() + client.close() + + @pytest.fixture def vector_io_adapter(vector_provider, request): vector_provider_dict = { @@ -457,6 +524,7 @@ def vector_io_adapter(vector_provider, request): "chroma": "chroma_vec_adapter", "qdrant": "qdrant_vec_adapter", "pgvector": "pgvector_vec_adapter", + "weaviate": "weaviate_vec_adapter", } return request.getfixturevalue(vector_provider_dict[vector_provider])