feat: implement keyword and hybrid search for Weaviate provider (#3264)

# What does this PR do?
<!-- Provide a short summary of what this PR does and why. Link to
relevant issues if applicable. -->
- This PR implements keyword and hybrid search for Weaviate DB based on
its inbuilt functions.
- Added fixtures to conftest.py for Weaviate.
- Enabled integration tests for remote Weaviate on all 3 search modes.

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->
Closes #3010 

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->
Unit tests and integration tests should pass on this PR.
This commit is contained in:
Christian Zaccaria 2025-10-03 09:22:30 +01:00 committed by GitHub
parent 52c8df2322
commit bcdbb53be3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 242 additions and 48 deletions

View file

@ -22,16 +22,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers:
if p.provider_type in [
"inline::faiss",
"inline::sqlite-vec",
"inline::milvus",
"inline::chromadb",
"remote::pgvector",
"remote::chromadb",
"remote::qdrant",
"inline::faiss",
"inline::milvus",
"inline::qdrant",
"remote::weaviate",
"inline::sqlite-vec",
"remote::chromadb",
"remote::milvus",
"remote::pgvector",
"remote::qdrant",
"remote::weaviate",
]:
return
@ -47,23 +47,25 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"inline::milvus",
"inline::chromadb",
"inline::qdrant",
"remote::pgvector",
"remote::chromadb",
"remote::weaviate",
"remote::qdrant",
"remote::milvus",
"remote::pgvector",
"remote::qdrant",
"remote::weaviate",
],
"keyword": [
"inline::milvus",
"inline::sqlite-vec",
"remote::milvus",
"inline::milvus",
"remote::pgvector",
"remote::weaviate",
],
"hybrid": [
"inline::sqlite-vec",
"inline::milvus",
"inline::sqlite-vec",
"remote::milvus",
"remote::pgvector",
"remote::weaviate",
],
}
supported_providers = search_mode_support.get(search_mode, [])

View file

@ -138,8 +138,8 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_io_provider_params_dict = {
"inline::milvus": {"score_threshold": -1.0},
"remote::qdrant": {"score_threshold": -1.0},
"inline::qdrant": {"score_threshold": -1.0},
"remote::qdrant": {"score_threshold": -1.0},
}
vector_db_name = "test_precomputed_embeddings_db"
register_response = client_with_empty_registry.vector_dbs.register(

View file

@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
def vector_provider(request):
return request.param
@ -448,6 +450,71 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
await adapter.shutdown()
@pytest.fixture(scope="session")
def weaviate_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
return db_path
@pytest.fixture
async def weaviate_vec_index(weaviate_vec_db_path):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
index = WeaviateIndex(client=client, collection_name="Testcollection")
await index.initialize()
yield index
await index.delete()
client.close()
@pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
config = WeaviateVectorIOConfig(
weaviate_cluster_url="localhost:8080",
weaviate_api_key=None,
kvstore=SqliteKVStoreConfig(),
)
adapter = WeaviateVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
client.close()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
vector_provider_dict = {
@ -457,6 +524,7 @@ def vector_io_adapter(vector_provider, request):
"chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter",
"weaviate": "weaviate_vec_adapter",
}
return request.getfixturevalue(vector_provider_dict[vector_provider])