mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: implement keyword and hybrid search for Weaviate provider (#3264)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> - This PR implements keyword and hybrid search for Weaviate DB based on its inbuilt functions. - Added fixtures to conftest.py for Weaviate. - Enabled integration tests for remote Weaviate on all 3 search modes. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> Closes #3010 ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> Unit tests and integration tests should pass on this PR.
This commit is contained in:
parent
52c8df2322
commit
bcdbb53be3
6 changed files with 242 additions and 48 deletions
|
@ -22,16 +22,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
|||
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
|
||||
for p in vector_io_providers:
|
||||
if p.provider_type in [
|
||||
"inline::faiss",
|
||||
"inline::sqlite-vec",
|
||||
"inline::milvus",
|
||||
"inline::chromadb",
|
||||
"remote::pgvector",
|
||||
"remote::chromadb",
|
||||
"remote::qdrant",
|
||||
"inline::faiss",
|
||||
"inline::milvus",
|
||||
"inline::qdrant",
|
||||
"remote::weaviate",
|
||||
"inline::sqlite-vec",
|
||||
"remote::chromadb",
|
||||
"remote::milvus",
|
||||
"remote::pgvector",
|
||||
"remote::qdrant",
|
||||
"remote::weaviate",
|
||||
]:
|
||||
return
|
||||
|
||||
|
@ -47,23 +47,25 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
|
|||
"inline::milvus",
|
||||
"inline::chromadb",
|
||||
"inline::qdrant",
|
||||
"remote::pgvector",
|
||||
"remote::chromadb",
|
||||
"remote::weaviate",
|
||||
"remote::qdrant",
|
||||
"remote::milvus",
|
||||
"remote::pgvector",
|
||||
"remote::qdrant",
|
||||
"remote::weaviate",
|
||||
],
|
||||
"keyword": [
|
||||
"inline::milvus",
|
||||
"inline::sqlite-vec",
|
||||
"remote::milvus",
|
||||
"inline::milvus",
|
||||
"remote::pgvector",
|
||||
"remote::weaviate",
|
||||
],
|
||||
"hybrid": [
|
||||
"inline::sqlite-vec",
|
||||
"inline::milvus",
|
||||
"inline::sqlite-vec",
|
||||
"remote::milvus",
|
||||
"remote::pgvector",
|
||||
"remote::weaviate",
|
||||
],
|
||||
}
|
||||
supported_providers = search_mode_support.get(search_mode, [])
|
||||
|
|
|
@ -138,8 +138,8 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
|
|||
def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||
vector_io_provider_params_dict = {
|
||||
"inline::milvus": {"score_threshold": -1.0},
|
||||
"remote::qdrant": {"score_threshold": -1.0},
|
||||
"inline::qdrant": {"score_threshold": -1.0},
|
||||
"remote::qdrant": {"score_threshold": -1.0},
|
||||
}
|
||||
vector_db_name = "test_precomputed_embeddings_db"
|
||||
register_response = client_with_empty_registry.vector_dbs.register(
|
||||
|
|
|
@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi
|
|||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
|
||||
|
||||
EMBEDDING_DIMENSION = 384
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
@ -448,6 +450,71 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def weaviate_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_index(weaviate_vec_db_path):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
index = WeaviateIndex(client=client, collection_name="Testcollection")
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
|
||||
config = WeaviateVectorIOConfig(
|
||||
weaviate_cluster_url="localhost:8080",
|
||||
weaviate_api_key=None,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = WeaviateVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
vector_provider_dict = {
|
||||
|
@ -457,6 +524,7 @@ def vector_io_adapter(vector_provider, request):
|
|||
"chroma": "chroma_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
"pgvector": "pgvector_vec_adapter",
|
||||
"weaviate": "weaviate_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue