mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
feat: Implement hybrid search in Milvus (#2644)
Some checks failed
Integration Tests (Replay) / discover-tests (push) Successful in 5s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 10s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 16s
Python Package Build Test / build (3.12) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 15s
Unit Tests / unit-tests (3.13) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 8s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 11s
Test External API and Providers / test-external (venv) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 19s
Pre-commit / pre-commit (push) Successful in 57s
Some checks failed
Integration Tests (Replay) / discover-tests (push) Successful in 5s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 10s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 16s
Python Package Build Test / build (3.12) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 15s
Unit Tests / unit-tests (3.13) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 8s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 11s
Test External API and Providers / test-external (venv) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 19s
Pre-commit / pre-commit (push) Successful in 57s
# What does this PR do? This PR implements hybrid search for Milvus DB based on the inbuilt milvus support. To test: ``` pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=long --disable-warnings --asyncio-mode=auto ``` Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
5a2d323eca
commit
e3928e6a29
4 changed files with 204 additions and 9 deletions
|
@ -10,7 +10,7 @@ import os
|
|||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
||||
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files.files import Files
|
||||
|
@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@ -238,7 +239,53 @@ class MilvusIndex(EmbeddingIndex):
|
|||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
||||
"""
|
||||
Hybrid search using Milvus's native hybrid search capabilities.
|
||||
|
||||
This implementation uses Milvus's hybrid_search method which combines
|
||||
vector search and BM25 search with configurable reranking strategies.
|
||||
"""
|
||||
search_requests = []
|
||||
|
||||
# nprobe: Controls search accuracy vs performance trade-off
|
||||
# 10 balances these trade-offs for RAG applications
|
||||
search_requests.append(
|
||||
AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k)
|
||||
)
|
||||
|
||||
# drop_ratio_search: Filters low-importance terms to improve search performance
|
||||
# 0.2 balances noise reduction with recall
|
||||
search_requests.append(
|
||||
AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k)
|
||||
)
|
||||
|
||||
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||
alpha = (reranker_params or {}).get("alpha", 0.5)
|
||||
rerank = WeightedRanker(alpha, 1 - alpha)
|
||||
else:
|
||||
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
|
||||
rerank = RRFRanker(impact_factor)
|
||||
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.hybrid_search,
|
||||
collection_name=self.collection_name,
|
||||
reqs=search_requests,
|
||||
ranker=rerank,
|
||||
limit=k,
|
||||
output_fields=["chunk_content"],
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for res in search_res[0]:
|
||||
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||
chunks.append(chunk)
|
||||
scores.append(res["distance"])
|
||||
|
||||
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||
|
||||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> None:
|
||||
"""Remove a chunk from the Milvus collection."""
|
||||
|
|
|
@ -302,23 +302,25 @@ class VectorDBWithIndex:
|
|||
mode = params.get("mode")
|
||||
score_threshold = params.get("score_threshold", 0.0)
|
||||
|
||||
# Get ranker configuration
|
||||
ranker = params.get("ranker")
|
||||
if ranker is None:
|
||||
# Default to RRF with impact_factor=60.0
|
||||
reranker_type = RERANKER_TYPE_RRF
|
||||
reranker_params = {"impact_factor": 60.0}
|
||||
else:
|
||||
reranker_type = ranker.type
|
||||
reranker_params = (
|
||||
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
|
||||
)
|
||||
strategy = ranker.get("strategy", "rrf")
|
||||
if strategy == "weighted":
|
||||
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
|
||||
reranker_type = RERANKER_TYPE_WEIGHTED
|
||||
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
|
||||
else:
|
||||
reranker_type = RERANKER_TYPE_RRF
|
||||
k_value = ranker.get("params", {}).get("k", 60.0)
|
||||
reranker_params = {"impact_factor": k_value}
|
||||
|
||||
query_string = interleaved_content_as_str(query)
|
||||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Calculate embeddings for both vector and hybrid modes
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
if mode == "hybrid":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue