mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: implement keyword and hybrid search for Weaviate provider (#3264)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> - This PR implements keyword and hybrid search for Weaviate DB based on its inbuilt functions. - Added fixtures to conftest.py for Weaviate. - Enabled integration tests for remote Weaviate on all 3 search modes. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> Closes #3010 ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> Unit tests and integration tests should pass on this PR.
This commit is contained in:
parent
52c8df2322
commit
bcdbb53be3
6 changed files with 242 additions and 48 deletions
|
@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi
|
|||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
|
||||
|
||||
EMBEDDING_DIMENSION = 384
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
@ -448,6 +450,71 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def weaviate_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_index(weaviate_vec_db_path):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
index = WeaviateIndex(client=client, collection_name="Testcollection")
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
|
||||
config = WeaviateVectorIOConfig(
|
||||
weaviate_cluster_url="localhost:8080",
|
||||
weaviate_api_key=None,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = WeaviateVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
vector_provider_dict = {
|
||||
|
@ -457,6 +524,7 @@ def vector_io_adapter(vector_provider, request):
|
|||
"chroma": "chroma_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
"pgvector": "pgvector_vec_adapter",
|
||||
"weaviate": "weaviate_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue