Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-15 06:00:48 +00:00.
[Feat] Implement keyword search in Qdrant
This commit implements keyword search in Qdrant. Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
5b312a80b9
commit
7da7ce9325
4 changed files with 163 additions and 17 deletions
|
@ -134,7 +134,50 @@ class QdrantIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||
"""Perform keyword search using Qdrant's payload filtering capabilities."""
|
||||
try:
|
||||
results = (
|
||||
await self.client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
query_filter=models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="chunk_content.content", match=models.MatchText(text=query_string)
|
||||
)
|
||||
]
|
||||
),
|
||||
limit=k,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
).points
|
||||
|
||||
chunks, scores = [], []
|
||||
for point in results:
|
||||
assert isinstance(point, models.ScoredPoint)
|
||||
assert point.payload is not None
|
||||
|
||||
try:
|
||||
chunk = Chunk(**point.payload["chunk_content"])
|
||||
except Exception:
|
||||
log.exception("Failed to parse chunk")
|
||||
continue
|
||||
|
||||
content = chunk.content.lower()
|
||||
query_lower = query_string.lower()
|
||||
if query_lower in content:
|
||||
# Binary scoring: 1.0 for any match, 0.0 for no match
|
||||
score = 1.0
|
||||
|
||||
if score >= score_threshold:
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error performing keyword search in Qdrant collection {self.collection_name}: {e}")
|
||||
return QueryChunksResponse(chunks=[], scores=[])
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
|
|
@ -55,6 +55,8 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
|
|||
],
|
||||
"keyword": [
|
||||
"inline::sqlite-vec",
|
||||
"inline::qdrant",
|
||||
"remote::qdrant",
|
||||
"remote::milvus",
|
||||
],
|
||||
"hybrid": [
|
||||
|
|
|
@ -4,7 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -17,7 +19,7 @@ from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOC
|
|||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
|
||||
|
@ -29,7 +31,7 @@ COLLECTION_PREFIX = "test_collection"
|
|||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "qdrant"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
@ -283,19 +285,22 @@ async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_d
|
|||
|
||||
|
||||
@pytest.fixture
def qdrant_vec_db_path(tmp_path):
    """Return a unique, isolated directory path for a Qdrant on-disk store.

    The directory is created *under* pytest's per-test ``tmp_path`` so that
    pytest reclaims it automatically even when the adapter fixture's own
    teardown never runs; the uuid in the prefix guards against collisions
    if several stores are created within one test.
    """
    import uuid

    # Anchoring mkdtemp under tmp_path (instead of the global temp dir) fixes
    # the leak where abandoned qdrant_test_* directories survived failed
    # teardowns, and makes the previously-unused tmp_path parameter meaningful.
    return tempfile.mkdtemp(prefix=f"qdrant_test_{uuid.uuid4()}_", dir=tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
config = QdrantVectorIOConfig(
|
||||
db_path=qdrant_vec_db_path,
|
||||
config = InlineQdrantVectorIOConfig(
|
||||
path=qdrant_vec_db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = QdrantVectorIOAdapter(
|
||||
|
@ -303,20 +308,30 @@ async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_d
|
|||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
|
||||
original_initialize = adapter.initialize
|
||||
|
||||
async def safe_initialize():
|
||||
if not hasattr(adapter, "_initialized") or not adapter._initialized:
|
||||
await original_initialize()
|
||||
adapter._initialized = True
|
||||
|
||||
adapter.initialize = safe_initialize
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
|
||||
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
adapter.test_collection_id = collection_id
|
||||
adapter._test_db_path = qdrant_vec_db_path
|
||||
yield adapter
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
try:
|
||||
if os.path.exists(qdrant_vec_db_path):
|
||||
shutil.rmtree(qdrant_vec_db_path, ignore_errors=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
|
||||
|
|
|
@ -136,3 +136,89 @@ async def test_qdrant_register_and_unregister_vector_db(
|
|||
await qdrant_adapter.unregister_vector_db(vector_db_id)
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|
||||
|
||||
|
||||
# Keyword search tests
|
||||
async def test_query_chunks_keyword_search(qdrant_vec_index, sample_chunks, sample_embeddings):
    """Keyword search returns chunks for a present term and nothing for an absent one."""
    await qdrant_vec_index.add_chunks(sample_chunks, sample_embeddings)

    # A term known to appear in the sample corpus must produce results.
    response = await qdrant_vec_index.query_keyword(query_string="Sentence 5", k=3, score_threshold=0.0)
    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) > 0, f"Expected some chunks, but got {len(response.chunks)}"

    # A nonsense term must produce an empty (but well-formed) response.
    response_no_results = await qdrant_vec_index.query_keyword(
        query_string="blablabla", k=1, score_threshold=0.0
    )
    assert isinstance(response_no_results, QueryChunksResponse)
    assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
||||
|
||||
|
||||
async def test_query_chunks_keyword_search_k_greater_than_results(qdrant_vec_index, sample_chunks, sample_embeddings):
    """Test keyword search when k is greater than available results."""
    await qdrant_vec_index.add_chunks(sample_chunks, sample_embeddings)

    unique_sentence = "Sentence 1 from document 0"  # should match only one chunk
    response = await qdrant_vec_index.query_keyword(k=5, score_threshold=0.0, query_string=unique_sentence)

    assert isinstance(response, QueryChunksResponse)
    assert 0 < len(response.chunks) < 5, f"Expected results between [1, 4], got {len(response.chunks)}"
    hits = [chunk for chunk in response.chunks if unique_sentence in chunk.content]
    assert hits, "Expected chunk not found"
|
||||
|
||||
|
||||
async def test_query_chunks_keyword_search_score_threshold(qdrant_vec_index, sample_chunks, sample_embeddings):
    """Test keyword search with score threshold filtering."""
    await qdrant_vec_index.add_chunks(sample_chunks, sample_embeddings)

    query = "Sentence 5"

    # A permissive threshold keeps the binary-score (1.0) matches.
    permissive = await qdrant_vec_index.query_keyword(query_string=query, k=3, score_threshold=0.0)
    assert len(permissive.chunks) > 0

    # An unattainably high threshold filters every match out.
    strict = await qdrant_vec_index.query_keyword(query_string=query, k=3, score_threshold=1000.0)
    assert len(strict.chunks) == 0
|
||||
|
||||
|
||||
async def test_query_chunks_keyword_search_edge_cases(qdrant_vec_index, sample_chunks, sample_embeddings):
    """Test keyword search edge cases."""
    await qdrant_vec_index.add_chunks(sample_chunks, sample_embeddings)

    # Each degenerate query must still yield a well-formed response.
    edge_queries = [
        "",                                 # empty string
        "a" * 1000,                         # very long query string
        "!@#$%^&*()_+-=[]{}|;':\",./<>?",   # special characters
    ]
    for edge_query in edge_queries:
        result = await qdrant_vec_index.query_keyword(query_string=edge_query, k=3, score_threshold=0.0)
        assert isinstance(result, QueryChunksResponse)
|
||||
|
||||
|
||||
async def test_query_chunks_keyword_search_metadata_preservation(
    qdrant_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata
):
    """Test that keyword search preserves chunk metadata."""
    await qdrant_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)

    response = await qdrant_vec_index.query_keyword(query_string="Sentence 0", k=2, score_threshold=0.0)

    assert len(response.chunks) > 0
    for chunk in response.chunks:
        # Returned chunks must still carry their metadata.
        assert hasattr(chunk, "metadata") or hasattr(chunk, "chunk_metadata")
        meta = getattr(chunk, "chunk_metadata", None)
        if meta:
            assert meta.document_id is not None
            assert meta.chunk_id is not None
            assert meta.source is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue