feat: implement keyword and hybrid search for Weaviate provider

This commit is contained in:
ChristianZaccaria 2025-08-27 12:24:38 +01:00
parent a1301911e4
commit 4541b517c8
8 changed files with 476 additions and 25 deletions

View file

@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
# Embedding vector width for these tests (presumably backs the
# `embedding_dimension` fixture - confirm against the rest of the file).
EMBEDDING_DIMENSION = 384
# Common prefix for collection names created by the test fixtures.
COLLECTION_PREFIX = "test_collection"
# NOTE(review): looks like the connection alias for the Milvus fixtures; its use is not visible here.
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
def vector_provider(request):
    """Return the provider name for the current parametrized run.

    Every test that depends on this fixture runs once per entry in
    ``params``. Only one ``@pytest.fixture`` decorator may be applied to a
    function, so the superseded parameter list (without "weaviate") is dropped.
    """
    return request.param
@ -446,6 +448,75 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
yield adapter
await adapter.shutdown()
@pytest.fixture
def weaviate_vec_db_path():
    """Return the ``host:port`` address of the Weaviate instance used by the tests.

    Registered as a fixture because the Weaviate index and adapter fixtures
    request it by name as an argument.
    """
    return "localhost:8080"
@pytest.fixture
async def weaviate_vec_index(weaviate_vec_db_path, embedding_dimension):
    """Yield a ``WeaviateIndex`` bound to a freshly created, uniquely named collection.

    Needs a Weaviate server reachable at ``weaviate_vec_db_path`` (``host:port``).
    ``embedding_dimension`` is requested but not used by this fixture.

    Teardown deletes the index; the client connection is closed even when
    setup or the delete step raises.
    """
    import uuid

    import weaviate
    import weaviate.classes as wvc

    from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name

    # Connect to the same address the adapter fixture uses instead of a
    # second hard-coded copy of host and port.
    host, _, port = weaviate_vec_db_path.rpartition(":")
    client = weaviate.connect_to_local(
        host=host,
        port=int(port),
    )
    try:
        collection_name = f"{COLLECTION_PREFIX}_{uuid.uuid4()}"
        index = WeaviateIndex(client=client, collection_name=collection_name)

        # Create the collection for this test through the public create() API
        # rather than the client's private `_CollectionConfig` class.
        sanitized_name = sanitize_collection_name(collection_name, weaviate_format=True)
        if not client.collections.exists(sanitized_name):
            client.collections.create(
                name=sanitized_name,
                vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                properties=[
                    wvc.config.Property(
                        name="chunk_content",
                        data_type=wvc.config.DataType.TEXT,
                    ),
                ],
            )

        yield index
        await index.delete()
    finally:
        # Always release the connection, including when setup or delete fails.
        client.close()
@pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
    """Yield an initialized ``WeaviateVectorIOAdapter`` with one registered vector DB.

    The identifier of the registered DB is exposed to tests as
    ``adapter.test_collection_id``. The adapter is shut down on teardown.
    """
    adapter = WeaviateVectorIOAdapter(
        config=WeaviateVectorIOConfig(
            weaviate_cluster_url=weaviate_vec_db_path,
            weaviate_api_key=None,
            kvstore=SqliteKVStoreConfig(),
        ),
        inference_api=mock_inference_api,
        files_api=None,
    )

    # Randomized suffix so each fixture instance registers its own identifier.
    collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"

    await adapter.initialize()
    vector_db = VectorDB(
        identifier=collection_id,
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=embedding_dimension,
    )
    await adapter.register_vector_db(vector_db)
    adapter.test_collection_id = collection_id

    yield adapter

    await adapter.shutdown()
@pytest.fixture
@ -457,6 +528,7 @@ def vector_io_adapter(vector_provider, request):
"chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter",
"weaviate": "weaviate_vec_adapter",
}
return request.getfixturevalue(vector_provider_dict[vector_provider])