mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: implement keyword and hybrid search for Weaviate provider
This commit is contained in:
parent
a1301911e4
commit
4541b517c8
8 changed files with 476 additions and 25 deletions
|
@ -10,7 +10,7 @@ import weaviate
|
||||||
import weaviate.classes as wvc
|
import weaviate.classes as wvc
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from weaviate.classes.init import Auth
|
from weaviate.classes.init import Auth
|
||||||
from weaviate.classes.query import Filter
|
from weaviate.classes.query import Filter, HybridFusion
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent
|
from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||||
|
@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
|
||||||
OpenAIVectorStoreMixin,
|
OpenAIVectorStoreMixin,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
RERANKER_TYPE_RRF,
|
||||||
ChunkForDeletion,
|
ChunkForDeletion,
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
VectorDBWithIndex,
|
VectorDBWithIndex,
|
||||||
|
@ -88,6 +89,9 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
|
collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids))
|
||||||
|
|
||||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
|
log.info(
|
||||||
|
f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||||
|
)
|
||||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||||
collection = self.client.collections.get(sanitized_collection_name)
|
collection = self.client.collections.get(sanitized_collection_name)
|
||||||
|
|
||||||
|
@ -115,6 +119,7 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
|
|
||||||
|
log.info(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def delete(self, chunk_ids: list[str] | None = None) -> None:
|
async def delete(self, chunk_ids: list[str] | None = None) -> None:
|
||||||
|
@ -136,7 +141,46 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
"""
|
||||||
|
Performs BM25-based keyword search using Weaviate's built-in full-text search.
|
||||||
|
Args:
|
||||||
|
query_string: The text query for keyword search
|
||||||
|
k: Limit of number of results to return
|
||||||
|
score_threshold: Minimum similarity score threshold
|
||||||
|
Returns:
|
||||||
|
QueryChunksResponse with combined results
|
||||||
|
"""
|
||||||
|
log.info(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
|
||||||
|
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||||
|
collection = self.client.collections.get(sanitized_collection_name)
|
||||||
|
|
||||||
|
# Perform BM25 keyword search on chunk_content field
|
||||||
|
results = collection.query.bm25(
|
||||||
|
query=query_string,
|
||||||
|
limit=k,
|
||||||
|
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for doc in results.objects:
|
||||||
|
chunk_json = doc.properties["chunk_content"]
|
||||||
|
try:
|
||||||
|
chunk_dict = json.loads(chunk_json)
|
||||||
|
chunk = Chunk(**chunk_dict)
|
||||||
|
except Exception:
|
||||||
|
log.exception(f"Failed to parse document: {chunk_json}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
score = doc.metadata.score if doc.metadata.score is not None else 0.0
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
log.info(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
async def query_hybrid(
|
async def query_hybrid(
|
||||||
self,
|
self,
|
||||||
|
@ -147,7 +191,62 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
reranker_type: str,
|
reranker_type: str,
|
||||||
reranker_params: dict[str, Any] | None = None,
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in Weaviate")
|
"""
|
||||||
|
Hybrid search combining vector similarity and keyword search using Weaviate's native hybrid search.
|
||||||
|
Args:
|
||||||
|
embedding: The query embedding vector
|
||||||
|
query_string: The text query for keyword search
|
||||||
|
k: Limit of number of results to return
|
||||||
|
score_threshold: Minimum similarity score threshold
|
||||||
|
reranker_type: Type of reranker to use ("rrf" or "normalized")
|
||||||
|
reranker_params: Parameters for the reranker
|
||||||
|
Returns:
|
||||||
|
QueryChunksResponse with combined results
|
||||||
|
"""
|
||||||
|
log.info(
|
||||||
|
f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}"
|
||||||
|
)
|
||||||
|
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||||
|
collection = self.client.collections.get(sanitized_collection_name)
|
||||||
|
|
||||||
|
# Ranked (RRF) reranker fusion type
|
||||||
|
if reranker_type == RERANKER_TYPE_RRF:
|
||||||
|
rerank = HybridFusion.RANKED
|
||||||
|
# Relative score (Normalized) reranker fusion type
|
||||||
|
else:
|
||||||
|
rerank = HybridFusion.RELATIVE_SCORE
|
||||||
|
|
||||||
|
# Perform hybrid search using Weaviate's native hybrid search
|
||||||
|
results = collection.query.hybrid(
|
||||||
|
query=query_string,
|
||||||
|
alpha=0.5, # Range <0, 1>, where 0.5 will equally favor vector and keyword search
|
||||||
|
vector=embedding.tolist(),
|
||||||
|
limit=k,
|
||||||
|
fusion_type=rerank,
|
||||||
|
return_metadata=wvc.query.MetadataQuery(score=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for doc in results.objects:
|
||||||
|
chunk_json = doc.properties["chunk_content"]
|
||||||
|
try:
|
||||||
|
chunk_dict = json.loads(chunk_json)
|
||||||
|
chunk = Chunk(**chunk_dict)
|
||||||
|
except Exception:
|
||||||
|
log.exception(f"Failed to parse document: {chunk_json}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
score = doc.metadata.score if doc.metadata.score is not None else 0.0
|
||||||
|
if score < score_threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
log.info(f"Document {chunk.metadata.get('document_id')} has score {score}")
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
log.info(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class WeaviateVectorIOAdapter(
|
class WeaviateVectorIOAdapter(
|
||||||
|
|
|
@ -50,6 +50,7 @@ class ChunkForDeletion(BaseModel):
|
||||||
# Constants for reranker types
|
# Constants for reranker types
|
||||||
RERANKER_TYPE_RRF = "rrf"
|
RERANKER_TYPE_RRF = "rrf"
|
||||||
RERANKER_TYPE_WEIGHTED = "weighted"
|
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||||
|
RERANKER_TYPE_NORMALIZED = "normalized"
|
||||||
|
|
||||||
|
|
||||||
def parse_pdf(data: bytes) -> str:
|
def parse_pdf(data: bytes) -> str:
|
||||||
|
@ -325,6 +326,8 @@ class VectorDBWithIndex:
|
||||||
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
|
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
|
||||||
reranker_type = RERANKER_TYPE_WEIGHTED
|
reranker_type = RERANKER_TYPE_WEIGHTED
|
||||||
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
|
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
|
||||||
|
elif strategy == "normalized":
|
||||||
|
reranker_type = RERANKER_TYPE_NORMALIZED
|
||||||
else:
|
else:
|
||||||
reranker_type = RERANKER_TYPE_RRF
|
reranker_type = RERANKER_TYPE_RRF
|
||||||
k_value = ranker.get("params", {}).get("k", 60.0)
|
k_value = ranker.get("params", {}).get("k", 60.0)
|
||||||
|
|
|
@ -25,8 +25,8 @@ classifiers = [
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"fastapi>=0.115.0,<1.0", # server
|
"fastapi>=0.115.0,<1.0", # server
|
||||||
"fire", # for MCP in LLS client
|
"fire", # for MCP in LLS client
|
||||||
"httpx",
|
"httpx",
|
||||||
"huggingface-hub>=0.34.0,<1.0",
|
"huggingface-hub>=0.34.0,<1.0",
|
||||||
"jinja2>=3.1.6",
|
"jinja2>=3.1.6",
|
||||||
|
@ -43,12 +43,13 @@ dependencies = [
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
"pillow",
|
"pillow",
|
||||||
"h11>=0.16.0",
|
"h11>=0.16.0",
|
||||||
"python-multipart>=0.0.20", # For fastapi Form
|
"python-multipart>=0.0.20", # For fastapi Form
|
||||||
"uvicorn>=0.34.0", # server
|
"uvicorn>=0.34.0", # server
|
||||||
"opentelemetry-sdk>=1.30.0", # server
|
"opentelemetry-sdk>=1.30.0", # server
|
||||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
|
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
|
||||||
"aiosqlite>=0.21.0", # server - for metadata store
|
"aiosqlite>=0.21.0", # server - for metadata store
|
||||||
"asyncpg", # for metadata store
|
"asyncpg", # for metadata store
|
||||||
|
"weaviate-client>=4.16.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
|
@ -22,16 +22,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
||||||
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
|
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
|
||||||
for p in vector_io_providers:
|
for p in vector_io_providers:
|
||||||
if p.provider_type in [
|
if p.provider_type in [
|
||||||
"inline::faiss",
|
|
||||||
"inline::sqlite-vec",
|
|
||||||
"inline::milvus",
|
|
||||||
"inline::chromadb",
|
"inline::chromadb",
|
||||||
"remote::pgvector",
|
"inline::faiss",
|
||||||
"remote::chromadb",
|
"inline::milvus",
|
||||||
"remote::qdrant",
|
|
||||||
"inline::qdrant",
|
"inline::qdrant",
|
||||||
"remote::weaviate",
|
"inline::sqlite-vec",
|
||||||
|
"remote::chromadb",
|
||||||
"remote::milvus",
|
"remote::milvus",
|
||||||
|
"remote::pgvector",
|
||||||
|
"remote::qdrant",
|
||||||
|
"remote::weaviate",
|
||||||
]:
|
]:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -47,23 +47,25 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
|
||||||
"inline::milvus",
|
"inline::milvus",
|
||||||
"inline::chromadb",
|
"inline::chromadb",
|
||||||
"inline::qdrant",
|
"inline::qdrant",
|
||||||
"remote::pgvector",
|
|
||||||
"remote::chromadb",
|
"remote::chromadb",
|
||||||
"remote::weaviate",
|
|
||||||
"remote::qdrant",
|
|
||||||
"remote::milvus",
|
"remote::milvus",
|
||||||
|
"remote::pgvector",
|
||||||
|
"remote::qdrant",
|
||||||
|
"remote::weaviate",
|
||||||
],
|
],
|
||||||
"keyword": [
|
"keyword": [
|
||||||
|
"inline::milvus",
|
||||||
"inline::sqlite-vec",
|
"inline::sqlite-vec",
|
||||||
"remote::milvus",
|
"remote::milvus",
|
||||||
"inline::milvus",
|
|
||||||
"remote::pgvector",
|
"remote::pgvector",
|
||||||
|
"remote::weaviate",
|
||||||
],
|
],
|
||||||
"hybrid": [
|
"hybrid": [
|
||||||
"inline::sqlite-vec",
|
|
||||||
"inline::milvus",
|
"inline::milvus",
|
||||||
|
"inline::sqlite-vec",
|
||||||
"remote::milvus",
|
"remote::milvus",
|
||||||
"remote::pgvector",
|
"remote::pgvector",
|
||||||
|
"remote::weaviate",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
supported_providers = search_mode_support.get(search_mode, [])
|
supported_providers = search_mode_support.get(search_mode, [])
|
||||||
|
|
|
@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi
|
||||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||||
|
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||||
|
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
|
||||||
|
|
||||||
EMBEDDING_DIMENSION = 384
|
EMBEDDING_DIMENSION = 384
|
||||||
COLLECTION_PREFIX = "test_collection"
|
COLLECTION_PREFIX = "test_collection"
|
||||||
MILVUS_ALIAS = "test_milvus"
|
MILVUS_ALIAS = "test_milvus"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
|
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
|
||||||
def vector_provider(request):
|
def vector_provider(request):
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
@ -446,6 +448,75 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
|
||||||
|
|
||||||
yield adapter
|
yield adapter
|
||||||
await adapter.shutdown()
|
await adapter.shutdown()
|
||||||
|
def weaviate_vec_db_path():
|
||||||
|
return "localhost:8080"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension):
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import weaviate
|
||||||
|
|
||||||
|
# Connect to local Weaviate instance
|
||||||
|
client = weaviate.connect_to_local(
|
||||||
|
host="localhost",
|
||||||
|
port=8080,
|
||||||
|
)
|
||||||
|
|
||||||
|
collection_name = f"{COLLECTION_PREFIX}_{uuid.uuid4()}"
|
||||||
|
index = WeaviateIndex(client=client, collection_name=collection_name)
|
||||||
|
|
||||||
|
# Create the collection for this test
|
||||||
|
import weaviate.classes as wvc
|
||||||
|
from weaviate.collections.classes.config import _CollectionConfig
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
|
||||||
|
|
||||||
|
sanitized_name = sanitize_collection_name(collection_name, weaviate_format=True)
|
||||||
|
collection_config = _CollectionConfig(
|
||||||
|
name=sanitized_name,
|
||||||
|
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
|
||||||
|
properties=[
|
||||||
|
wvc.config.Property(
|
||||||
|
name="chunk_content",
|
||||||
|
data_type=wvc.config.DataType.TEXT,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if not client.collections.exists(sanitized_name):
|
||||||
|
client.collections.create_from_config(collection_config)
|
||||||
|
|
||||||
|
yield index
|
||||||
|
await index.delete()
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
|
||||||
|
config = WeaviateVectorIOConfig(
|
||||||
|
weaviate_cluster_url=weaviate_vec_db_path,
|
||||||
|
weaviate_api_key=None,
|
||||||
|
kvstore=SqliteKVStoreConfig(),
|
||||||
|
)
|
||||||
|
adapter = WeaviateVectorIOAdapter(
|
||||||
|
config=config,
|
||||||
|
inference_api=mock_inference_api,
|
||||||
|
files_api=None,
|
||||||
|
)
|
||||||
|
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
|
||||||
|
await adapter.initialize()
|
||||||
|
await adapter.register_vector_db(
|
||||||
|
VectorDB(
|
||||||
|
identifier=collection_id,
|
||||||
|
provider_id="test_provider",
|
||||||
|
embedding_model="test_model",
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
adapter.test_collection_id = collection_id
|
||||||
|
yield adapter
|
||||||
|
await adapter.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -457,6 +528,7 @@ def vector_io_adapter(vector_provider, request):
|
||||||
"chroma": "chroma_vec_adapter",
|
"chroma": "chroma_vec_adapter",
|
||||||
"qdrant": "qdrant_vec_adapter",
|
"qdrant": "qdrant_vec_adapter",
|
||||||
"pgvector": "pgvector_vec_adapter",
|
"pgvector": "pgvector_vec_adapter",
|
||||||
|
"weaviate": "weaviate_vec_adapter",
|
||||||
}
|
}
|
||||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||||
|
|
||||||
|
|
|
@ -23,13 +23,13 @@ pymilvus_mock.AnnSearchRequest = MagicMock
|
||||||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||||
|
|
||||||
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
# This test is a unit test for the MilvusIndex class. This should only contain
|
||||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||||
# tests/integration/vector_io/
|
# tests/integration/vector_io/
|
||||||
#
|
#
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
#
|
#
|
||||||
# pytest tests/unit/providers/vector_io/test_milvus.py \
|
# pytest tests/unit/providers/vector_io/remote/test_milvus.py \
|
||||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||||
|
|
||||||
MILVUS_PROVIDER = "milvus"
|
MILVUS_PROVIDER = "milvus"
|
||||||
|
@ -324,3 +324,6 @@ async def test_query_hybrid_search_default_rrf(
|
||||||
call_args = mock_milvus_client.hybrid_search.call_args
|
call_args = mock_milvus_client.hybrid_search.call_args
|
||||||
ranker = call_args[1]["ranker"]
|
ranker = call_args[1]["ranker"]
|
||||||
assert ranker is not None
|
assert ranker is not None
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Write tests for the MilvusVectorIOAdapter class.
|
||||||
|
|
269
tests/unit/providers/vector_io/remote/test_weaviate.py
Normal file
269
tests/unit/providers/vector_io/remote/test_weaviate.py
Normal file
|
@ -0,0 +1,269 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import ANY, MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||||
|
|
||||||
|
# Mock the Weaviate client
|
||||||
|
weaviate_mock = MagicMock()
|
||||||
|
|
||||||
|
# Apply the mock before importing WeaviateIndex
|
||||||
|
with patch.dict("sys.modules", {"weaviate": weaviate_mock}):
|
||||||
|
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex
|
||||||
|
|
||||||
|
# This test is a unit test for the WeaviateIndex class. This should only contain
|
||||||
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||||
|
# tests/integration/vector_io/
|
||||||
|
#
|
||||||
|
# How to run this test:
|
||||||
|
#
|
||||||
|
# pytest tests/unit/providers/vector_io/remote/test_weaviate.py \
|
||||||
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||||
|
|
||||||
|
WEAVIATE_PROVIDER = "weaviate"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_weaviate_client() -> MagicMock:
|
||||||
|
"""Create a mock Weaviate client with common method behaviors."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_collection = MagicMock()
|
||||||
|
|
||||||
|
# Mock collection data operations
|
||||||
|
mock_collection.data.insert_many.return_value = None
|
||||||
|
mock_collection.data.delete_many.return_value = None
|
||||||
|
|
||||||
|
# Mock collection search operations
|
||||||
|
mock_collection.query.near_vector.return_value = None
|
||||||
|
mock_collection.query.bm25.return_value = None
|
||||||
|
mock_collection.query.hybrid.return_value = None
|
||||||
|
mock_results = MagicMock()
|
||||||
|
mock_results.objects = [MagicMock(), MagicMock()]
|
||||||
|
mock_collection.query.near_vector.return_value = mock_results
|
||||||
|
|
||||||
|
# Mock client collection operations
|
||||||
|
mock_client.collections.get.return_value = mock_collection
|
||||||
|
mock_client.collections.exists.return_value = True
|
||||||
|
mock_client.collections.delete.return_value = None
|
||||||
|
|
||||||
|
# Mock client close operation
|
||||||
|
mock_client.close.return_value = None
|
||||||
|
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def weaviate_index(mock_weaviate_client):
|
||||||
|
"""Create a WeaviateIndex with mocked client."""
|
||||||
|
index = WeaviateIndex(client=mock_weaviate_client, collection_name="Testcollection")
|
||||||
|
yield index
|
||||||
|
# No real cleanup needed since we're using mocks
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_chunks(weaviate_index, sample_chunks, sample_embeddings, mock_weaviate_client):
|
||||||
|
# Setup: Add chunks first
|
||||||
|
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
mock_weaviate_client.collections.get.assert_called_once_with("Testcollection")
|
||||||
|
mock_weaviate_client.collections.get.return_value.data.insert_many.assert_called_once()
|
||||||
|
|
||||||
|
# Verify the insert call had the right number of chunks
|
||||||
|
data_objects, _ = mock_weaviate_client.collections.get.return_value.data.insert_many.call_args
|
||||||
|
assert len(data_objects[0]) == len(sample_chunks)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_query_chunks_vector(
|
||||||
|
weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client
|
||||||
|
):
|
||||||
|
# Setup: Add chunks first
|
||||||
|
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Create mock objects that match Weaviate's response structure
|
||||||
|
mock_objects = []
|
||||||
|
for i, chunk in enumerate(sample_chunks[:2]): # Return first 2 chunks
|
||||||
|
mock_obj = MagicMock()
|
||||||
|
mock_obj.properties = {"chunk_content": chunk.model_dump_json()}
|
||||||
|
mock_obj.metadata.distance = 0.1 + i * 0.1 # Mock distances
|
||||||
|
mock_objects.append(mock_obj)
|
||||||
|
|
||||||
|
mock_results = MagicMock()
|
||||||
|
mock_results.objects = mock_objects
|
||||||
|
mock_weaviate_client.collections.get.return_value.query.near_vector.return_value = mock_results
|
||||||
|
|
||||||
|
# Test vector search
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
response = await weaviate_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
assert len(response.scores) == 2
|
||||||
|
mock_weaviate_client.collections.get.return_value.query.near_vector.assert_called_once_with(
|
||||||
|
near_vector=query_embedding.tolist(),
|
||||||
|
limit=2,
|
||||||
|
return_metadata=ANY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_query_chunks_keyword_search(weaviate_index, sample_chunks, sample_embeddings, mock_weaviate_client):
|
||||||
|
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Find chunks that contain "Sentence 5"
|
||||||
|
matching_chunks = [chunk for chunk in sample_chunks if "Sentence 5" in chunk.content]
|
||||||
|
|
||||||
|
# Create mock objects that match Weaviate's BM25 response structure
|
||||||
|
# Return the first 2 matching chunks
|
||||||
|
mock_objects = []
|
||||||
|
for i, chunk in enumerate(matching_chunks[:2]):
|
||||||
|
mock_obj = MagicMock()
|
||||||
|
mock_obj.properties = {"chunk_content": chunk.model_dump_json()}
|
||||||
|
mock_obj.metadata.score = 0.9 - i * 0.1
|
||||||
|
mock_objects.append(mock_obj)
|
||||||
|
|
||||||
|
mock_results = MagicMock()
|
||||||
|
mock_results.objects = mock_objects
|
||||||
|
mock_weaviate_client.collections.get.return_value.query.bm25.return_value = mock_results
|
||||||
|
|
||||||
|
# Test keyword search
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
response = await weaviate_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
assert len(response.scores) == 2
|
||||||
|
# Verify that the returned chunks contain the search term
|
||||||
|
assert all("Sentence 5" in chunk.content for chunk in response.chunks)
|
||||||
|
mock_weaviate_client.collections.get.return_value.query.bm25.assert_called_once_with(
|
||||||
|
query=query_string,
|
||||||
|
limit=2,
|
||||||
|
return_metadata=ANY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_collection(weaviate_index, mock_weaviate_client):
|
||||||
|
# Test collection deletion (when chunk_ids is None, it deletes the entire collection)
|
||||||
|
mock_weaviate_client.collections.exists.return_value = True
|
||||||
|
|
||||||
|
await weaviate_index.delete()
|
||||||
|
|
||||||
|
mock_weaviate_client.collections.exists.assert_called_once_with("Testcollection")
|
||||||
|
mock_weaviate_client.collections.delete.assert_called_once_with("Testcollection")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_delete_chunks(weaviate_index, mock_weaviate_client):
|
||||||
|
# Test deleting specific chunks using ChunkForDeletion objects
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion
|
||||||
|
|
||||||
|
chunks_for_deletion = [
|
||||||
|
ChunkForDeletion(chunk_id="chunk-1", document_id="doc-1"),
|
||||||
|
ChunkForDeletion(chunk_id="chunk-2", document_id="doc-1"),
|
||||||
|
ChunkForDeletion(chunk_id="chunk-3", document_id="doc-2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
await weaviate_index.delete_chunks(chunks_for_deletion)
|
||||||
|
|
||||||
|
mock_weaviate_client.collections.get.assert_called_once_with("Testcollection")
|
||||||
|
mock_weaviate_client.collections.get.return_value.data.delete_many.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_query_hybrid_rrf(
|
||||||
|
weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client
|
||||||
|
):
|
||||||
|
# Test hybrid search with RRF reranking
|
||||||
|
from weaviate.classes.query import HybridFusion
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import RERANKER_TYPE_RRF
|
||||||
|
|
||||||
|
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Find chunks that contain "Sentence 5"
|
||||||
|
matching_chunks = [chunk for chunk in sample_chunks if "Sentence 5" in chunk.content]
|
||||||
|
|
||||||
|
# Create mock objects for hybrid search response
|
||||||
|
mock_objects = []
|
||||||
|
for i, chunk in enumerate(matching_chunks[:2]):
|
||||||
|
mock_obj = MagicMock()
|
||||||
|
mock_obj.properties = {"chunk_content": chunk.model_dump_json()}
|
||||||
|
mock_obj.metadata.score = 0.85 + i * 0.05
|
||||||
|
mock_objects.append(mock_obj)
|
||||||
|
|
||||||
|
mock_results = MagicMock()
|
||||||
|
mock_results.objects = mock_objects
|
||||||
|
mock_weaviate_client.collections.get.return_value.query.hybrid.return_value = mock_results
|
||||||
|
|
||||||
|
# Test hybrid search with RRF
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
response = await weaviate_index.query_hybrid(
|
||||||
|
embedding=query_embedding, query_string=query_string, k=2, score_threshold=0.0, reranker_type=RERANKER_TYPE_RRF
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
assert len(response.scores) == 2
|
||||||
|
assert all("Sentence 5" in chunk.content for chunk in response.chunks)
|
||||||
|
|
||||||
|
# Verify the hybrid method was called with correct parameters
|
||||||
|
mock_weaviate_client.collections.get.return_value.query.hybrid.assert_called_once_with(
|
||||||
|
query=query_string,
|
||||||
|
alpha=0.5,
|
||||||
|
vector=query_embedding.tolist(),
|
||||||
|
limit=2,
|
||||||
|
fusion_type=HybridFusion.RANKED,
|
||||||
|
return_metadata=ANY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_query_hybrid_normalized(
|
||||||
|
weaviate_index, sample_chunks, sample_embeddings, embedding_dimension, mock_weaviate_client
|
||||||
|
):
|
||||||
|
from weaviate.classes.query import HybridFusion
|
||||||
|
|
||||||
|
await weaviate_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Find chunks that contain "Sentence 3" for different results
|
||||||
|
matching_chunks = [chunk for chunk in sample_chunks if "Sentence 3" in chunk.content]
|
||||||
|
|
||||||
|
# Create mock objects for hybrid search response
|
||||||
|
mock_objects = []
|
||||||
|
for i, chunk in enumerate(matching_chunks[:2]):
|
||||||
|
mock_obj = MagicMock()
|
||||||
|
mock_obj.properties = {"chunk_content": chunk.model_dump_json()}
|
||||||
|
mock_obj.metadata.score = 0.75 + i * 0.1 # Mock hybrid scores
|
||||||
|
mock_objects.append(mock_obj)
|
||||||
|
|
||||||
|
mock_results = MagicMock()
|
||||||
|
mock_results.objects = mock_objects
|
||||||
|
mock_weaviate_client.collections.get.return_value.query.hybrid.return_value = mock_results
|
||||||
|
|
||||||
|
# Test hybrid search with normalized reranking
|
||||||
|
query_string = "Sentence 3"
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
response = await weaviate_index.query_hybrid(
|
||||||
|
embedding=query_embedding, query_string=query_string, k=2, score_threshold=0.0, reranker_type="normalized"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
assert len(response.scores) == 2
|
||||||
|
assert all("Sentence 3" in chunk.content for chunk in response.chunks)
|
||||||
|
|
||||||
|
# Verify the hybrid method was called with correct parameters
|
||||||
|
mock_weaviate_client.collections.get.return_value.query.hybrid.assert_called_once_with(
|
||||||
|
query=query_string,
|
||||||
|
alpha=0.5,
|
||||||
|
vector=query_embedding.tolist(),
|
||||||
|
limit=2,
|
||||||
|
fusion_type=HybridFusion.RELATIVE_SCORE,
|
||||||
|
return_metadata=ANY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Write tests for the WeaviateVectorIOAdapter class.
|
2
uv.lock
generated
2
uv.lock
generated
|
@ -1777,6 +1777,7 @@ dependencies = [
|
||||||
{ name = "termcolor" },
|
{ name = "termcolor" },
|
||||||
{ name = "tiktoken" },
|
{ name = "tiktoken" },
|
||||||
{ name = "uvicorn" },
|
{ name = "uvicorn" },
|
||||||
|
{ name = "weaviate-client" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
|
@ -1904,6 +1905,7 @@ requires-dist = [
|
||||||
{ name = "termcolor" },
|
{ name = "termcolor" },
|
||||||
{ name = "tiktoken" },
|
{ name = "tiktoken" },
|
||||||
{ name = "uvicorn", specifier = ">=0.34.0" },
|
{ name = "uvicorn", specifier = ">=0.34.0" },
|
||||||
|
{ name = "weaviate-client", specifier = ">=4.16.5" },
|
||||||
]
|
]
|
||||||
provides-extras = ["ui"]
|
provides-extras = ["ui"]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue