mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(qdrant): implement hybrid and keyword search support
- Implement hybrid search using Qdrant's native query filtering - Add keyword search support - Update test suites to include qdrant for keyword and hybrid modes Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
e243892ef0
commit
72bef1922c
3 changed files with 228 additions and 3 deletions
|
|
@ -14,17 +14,19 @@ from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreCo
|
|||
from llama_stack.core.storage.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex, QdrantVectorIOAdapter
|
||||
from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore
|
||||
|
||||
EMBEDDING_DIMENSION = 768
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
|
||||
|
||||
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
|
||||
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector", "qdrant"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
|
@ -317,12 +319,116 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_index(embedding_dimension):
|
||||
from qdrant_client import models
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.collection_exists.return_value = False
|
||||
mock_client.create_collection = AsyncMock()
|
||||
mock_client.query_points = AsyncMock(return_value=AsyncMock(points=[]))
|
||||
mock_client.delete_collection = AsyncMock()
|
||||
|
||||
collection_name = f"test-qdrant-collection-{random.randint(1, 1000000)}"
|
||||
index = QdrantIndex(mock_client, collection_name)
|
||||
index._test_chunks = []
|
||||
|
||||
async def mock_add_chunks(chunks, embeddings):
|
||||
index._test_chunks = list(chunks)
|
||||
# Create mock query response with test chunks
|
||||
mock_points = []
|
||||
for chunk in chunks:
|
||||
mock_point = MagicMock(spec=models.ScoredPoint)
|
||||
mock_point.score = 1.0
|
||||
mock_point.payload = {"chunk_content": chunk.model_dump(), "_chunk_id": chunk.chunk_id}
|
||||
mock_points.append(mock_point)
|
||||
|
||||
async def query_points_mock(**kwargs):
|
||||
# Return chunks in order when queried
|
||||
query_k = kwargs.get("limit", len(index._test_chunks))
|
||||
return AsyncMock(points=mock_points[:query_k])
|
||||
|
||||
mock_client.query_points = query_points_mock
|
||||
|
||||
index.add_chunks = mock_add_chunks
|
||||
|
||||
async def mock_query_vector(embedding, k, score_threshold):
|
||||
chunks = index._test_chunks[:k] if hasattr(index, "_test_chunks") else []
|
||||
scores = [1.0] * len(chunks)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
index.query_vector = mock_query_vector
|
||||
|
||||
yield index
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = QdrantVectorIOConfig(
|
||||
path=":memory:",
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
|
||||
adapter = QdrantVectorIOAdapter(config, mock_inference_api, None)
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.collection_exists.return_value = False
|
||||
mock_client.create_collection = AsyncMock()
|
||||
mock_client.query_points = AsyncMock(return_value=AsyncMock(points=[]))
|
||||
mock_client.delete_collection = AsyncMock()
|
||||
mock_client.close = AsyncMock()
|
||||
mock_client.upsert = AsyncMock()
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.qdrant.qdrant.AsyncQdrantClient") as mock_client_class:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
with patch("llama_stack.core.storage.kvstore.kvstore_impl") as mock_kvstore_impl:
|
||||
mock_kvstore = AsyncMock()
|
||||
mock_kvstore.values_in_range.return_value = []
|
||||
mock_kvstore_impl.return_value = mock_kvstore
|
||||
|
||||
with patch.object(adapter, "initialize_openai_vector_stores", new_callable=AsyncMock):
|
||||
await adapter.initialize()
|
||||
adapter.client = mock_client
|
||||
|
||||
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
|
||||
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
adapter.insert_chunks = mock_insert_chunks
|
||||
|
||||
async def mock_query_chunks(vector_store_id, query, params=None):
|
||||
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
adapter.query_chunks = mock_query_chunks
|
||||
|
||||
test_vector_store = VectorStore(
|
||||
identifier=f"qdrant_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
await adapter.register_vector_store(test_vector_store)
|
||||
adapter.test_collection_id = test_vector_store.identifier
|
||||
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
vector_provider_dict = {
|
||||
"faiss": "faiss_vec_adapter",
|
||||
"sqlite_vec": "sqlite_vec_adapter",
|
||||
"pgvector": "pgvector_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue